In [1]:
# run "pip installipython-autotime" in your conda env
%load_ext autotime
print('hello world')
hello world time: 242 μs (started: 2025-11-03 21:13:18 +01:00)
In [2]:
# Import bayesDREAM with reload capability
import importlib
import sys
import os
import torch
from pathlib import Path
# Add the directory containing 'bayesDREAM' to sys.path
base_path = Path('/cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/CRISPRmodelling/BayesianModelling/bayesDREAM_forClaude/bayesDREAM code/bayesDREAM_forClaude/') # Adjust relative path from your notebook
sys.path.append(str(base_path))
# Now import the model
import bayesDREAM
# Reload only if the module was already imported
importlib.reload(bayesDREAM)
from bayesDREAM import bayesDREAM
time: 1min 57s (started: 2025-11-03 21:13:18 +01:00)
In [3]:
import pandas as pd
time: 167 μs (started: 2025-11-03 21:15:15 +01:00)
In [4]:
deviceno = 0
device = torch.device(f'cuda:{deviceno}' if torch.cuda.is_available() else 'cpu')
time: 53.5 ms (started: 2025-11-03 21:15:15 +01:00)
Load data¶
In [5]:
data_dir = '/cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/'
cell_meta = pd.read_csv(data_dir + '10X_SR_cell_meta.csv')
gene_meta = pd.read_csv(data_dir + '10X_SR_gene_meta.csv')
gene_counts = pd.read_csv(data_dir + '10X_SR_counts.csv', index_col=None)
gene_counts.index = gene_meta['Symbol'].values
gene_meta.index = gene_meta['Symbol'].values
gene_meta = gene_meta.rename(columns={'ID': 'gene_id', 'Symbol': 'gene_name'})
cell_meta['cell'] = cell_meta['Barcode']
cell_meta.loc[cell_meta['target'] == 'NTC', 'target'] = 'ntc'
time: 2.56 s (started: 2025-11-03 21:15:15 +01:00)
Create bayesDREAM objects¶
In [6]:
model = {}
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
model[cg] = bayesDREAM(
meta=cell_meta,
counts=gene_counts,
gene_meta=gene_meta,
cis_gene=cg,
output_dir=data_dir+'bayesDREAM/output/',
sum_factor_col = 'clustered.sum.factor',
label='20251030_' + cg,
device = device
)
[INFO] Extracting 'cis' modality from gene 'GFI1B' [INFO] Creating 'gene' modality with trans genes (excluding 'GFI1B') [VALIDATION] Primary modality 'gene' is 'negbinom' - cis modeling is valid [INFO] Using 'gene_name' column as 'gene' identifier [INFO] Gene metadata loaded with 13792 genes and columns: ['gene', 'gene_name', 'gene_id']
/cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/CRISPRmodelling/BayesianModelling/bayesDREAM_forClaude/bayesDREAM code/bayesDREAM_forClaude/bayesDREAM/core.py:220: UserWarning: Subsetting reduced the number of cells in the metadata from 2600 to 1884. This may impact downstream analysis. warnings.warn(
[INIT] bayesDREAM core: label=20251030_GFI1B, device=cpu
[INFO] Subsetting modalities to 1884 cells from filtered meta
[INFO] Subsetting modality 'cis' from 2600 to 1884 cells
[INFO] Subsetting modality 'gene' from 2600 to 1884 cells
[INIT] bayesDREAM: 2 modalities loaded
- cis: Modality(name='cis', distribution='negbinom', dims={'n_features': 1, 'n_cells': 1884})
- gene: Modality(name='gene', distribution='negbinom', dims={'n_features': 13791, 'n_cells': 1884})
[INFO] Extracting 'cis' modality from gene 'GEMIN5'
[INFO] Creating 'gene' modality with trans genes (excluding 'GEMIN5')
[VALIDATION] Primary modality 'gene' is 'negbinom' - cis modeling is valid
[INFO] Using 'gene_name' column as 'gene' identifier
[INFO] Gene metadata loaded with 13792 genes and columns: ['gene', 'gene_name', 'gene_id']
[INIT] bayesDREAM core: label=20251030_GEMIN5, device=cpu
[INFO] Subsetting modalities to 1879 cells from filtered meta
/cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/CRISPRmodelling/BayesianModelling/bayesDREAM_forClaude/bayesDREAM code/bayesDREAM_forClaude/bayesDREAM/core.py:220: UserWarning: Subsetting reduced the number of cells in the metadata from 2600 to 1879. This may impact downstream analysis. warnings.warn(
[INFO] Subsetting modality 'cis' from 2600 to 1879 cells
[INFO] Subsetting modality 'gene' from 2600 to 1879 cells
[INIT] bayesDREAM: 2 modalities loaded
- cis: Modality(name='cis', distribution='negbinom', dims={'n_features': 1, 'n_cells': 1879})
- gene: Modality(name='gene', distribution='negbinom', dims={'n_features': 13791, 'n_cells': 1879})
[INFO] Extracting 'cis' modality from gene 'DDX6'
[INFO] Creating 'gene' modality with trans genes (excluding 'DDX6')
[VALIDATION] Primary modality 'gene' is 'negbinom' - cis modeling is valid
[INFO] Using 'gene_name' column as 'gene' identifier
[INFO] Gene metadata loaded with 13792 genes and columns: ['gene', 'gene_name', 'gene_id']
[INIT] bayesDREAM core: label=20251030_DDX6, device=cpu
[INFO] Subsetting modalities to 1819 cells from filtered meta
/cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/CRISPRmodelling/BayesianModelling/bayesDREAM_forClaude/bayesDREAM code/bayesDREAM_forClaude/bayesDREAM/core.py:220: UserWarning: Subsetting reduced the number of cells in the metadata from 2600 to 1819. This may impact downstream analysis. warnings.warn(
[INFO] Subsetting modality 'cis' from 2600 to 1819 cells
[INFO] Subsetting modality 'gene' from 2600 to 1819 cells
[INIT] bayesDREAM: 2 modalities loaded
- cis: Modality(name='cis', distribution='negbinom', dims={'n_features': 1, 'n_cells': 1819})
- gene: Modality(name='gene', distribution='negbinom', dims={'n_features': 13791, 'n_cells': 1819})
time: 1.82 s (started: 2025-11-03 21:15:18 +01:00)
In [7]:
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
model[cg].adjust_ntc_sum_factor(sum_factor_col_old='clustered.sum.factor')
[INFO] Created 'sum_factor_adj' in meta with NTC-based guide-level adjustment. [INFO] Created 'sum_factor_adj' in meta with NTC-based guide-level adjustment. [INFO] Created 'sum_factor_adj' in meta with NTC-based guide-level adjustment. time: 108 ms (started: 2025-11-03 21:15:20 +01:00)
In [8]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
cgs = ['GFI1B', 'GEMIN5', 'DDX6']
# --- Define consistent palettes ---
palette = {
'GFI1B': [cm.Greens(i) for i in np.linspace(0.4, 0.9, 3)], # GFI1B_[1-3]
'NTC': [cm.Greys(i) for i in np.linspace(0.4, 0.8, 5)], # NTC_[1-5]
'GEMIN5':[cm.Blues(i) for i in np.linspace(0.4, 0.8, 2)], # GEMIN5_[1-2]
'DDX6': [cm.Reds(i) for i in np.linspace(0.4, 0.8, 3)], # DDX6_[1,3]
}
# Flatten into guide→color dictionary
guide_colors = {}
for gene, colors in palette.items():
for i, color in enumerate(colors, start=1):
guide_colors[f"{gene}_{i}"] = color
# --- Plot ---
fig, axes = plt.subplots(1, len(cgs), figsize=(5*len(cgs), 4), sharex=False, sharey=False)
if len(cgs) == 1:
axes = [axes]
for ax, cg in zip(axes, cgs):
df = model[cg].meta.copy()
df = df[(df['clustered.sum.factor'] > 0) & (df['sum_factor_adj'] > 0)]
for guide, sub in df.groupby('guide'):
color = guide_colors.get(guide, 'black')
ax.scatter(
sub['clustered.sum.factor'],
sub['sum_factor_adj'],
s=12,
alpha=0.8,
color=color,
label=guide,
)
ax.set_xscale('log', base=2)
ax.set_yscale('log', base=2)
ax.set_title(cg)
ax.set_xlabel('clustered.sum.factor (log₂)')
ax.set_ylabel('sum_factor_adj (log₂)')
ax.grid(True, linewidth=0.5, alpha=0.4)
ax.legend(title='guide', fontsize=8, markerscale=1.2, frameon=False)
plt.tight_layout()
plt.show()
time: 1.38 s (started: 2025-11-03 21:15:20 +01:00)
Fit cis¶
In [9]:
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
print(cg)
print(model[cg].list_modalities())
print()
GFI1B name distribution n_features n_cells 0 cis negbinom 1 1884 1 gene negbinom 13791 1884 GEMIN5 name distribution n_features n_cells 0 cis negbinom 1 1879 1 gene negbinom 13791 1879 DDX6 name distribution n_features n_cells 0 cis negbinom 1 1819 1 gene negbinom 13791 1819 time: 19.2 ms (started: 2025-11-03 21:15:21 +01:00)
fit¶
In [10]:
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
# --- CIS FIT: Load if exists, otherwise fit and save ---
cis_fit_path = os.path.join(model[cg].output_dir, model[cg].label, 'x_true.pt')
if os.path.exists(cis_fit_path):
print("[INFO] Loading existing cis fit...")
model[cg].load_cis_fit()
else:
print("[INFO] Running cis fit (this may take a while)...")
model[cg].fit_cis(sum_factor_col="sum_factor_adj", tolerance=0, niters=100000)
model[cg].save_cis_fit()
[INFO] Loading existing cis fit... [LOAD] x_true (posterior) ← /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_GFI1B/x_true.pt [LOAD] Cis fit loaded from /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_GFI1B [INFO] Loading existing cis fit... [LOAD] x_true (posterior) ← /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_GEMIN5/x_true.pt [LOAD] Cis fit loaded from /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_GEMIN5 [INFO] Loading existing cis fit... [LOAD] x_true (posterior) ← /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_DDX6/x_true.pt [LOAD] Cis fit loaded from /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_DDX6 time: 772 ms (started: 2025-11-03 21:15:21 +01:00)
Plot results¶
In [11]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from scipy.stats import gaussian_kde
# ----------------------------
# Palette and color utilities
# ----------------------------
palette = {
'GFI1B': [cm.Greens(i) for i in np.linspace(0.4, 0.9, 3)], # GFI1B_[1-3]
'NTC': [cm.Greys(i) for i in np.linspace(0.4, 0.8, 5)], # NTC_[1-5]
'GEMIN5':[cm.Blues(i) for i in np.linspace(0.4, 0.8, 2)], # GEMIN5_[1-2]
'DDX6': [cm.Reds(i) for i in np.linspace(0.4, 0.8, 3)], # DDX6_[1,3]
}
def build_guide_colors(palette_dict):
guide_colors = {}
for gene, colors in palette_dict.items():
for i, color in enumerate(colors, start=1):
guide_colors[f"{gene}_{i}"] = color
return guide_colors
guide_colors = build_guide_colors(palette)
# Optional: target color map (violin plots)
target_colors = {
'GFI1B': cm.Greens(0.7),
'NTC': cm.Greys(0.6),
'GEMIN5':cm.Blues(0.7),
'DDX6': cm.Reds(0.7),
'ntc': cm.Greys(0.6), # if you've normalized NTC to lowercase
}
# ----------------------------
# Helpers
# ----------------------------
def to_np(a):
"""Safely convert torch/array-like to numpy."""
try:
import torch
if isinstance(a, torch.Tensor):
return a.detach().cpu().numpy()
except Exception:
pass
return np.asarray(a)
def per_cell_mean_std(x):
"""Compute per-cell mean and std along axis 0 (samples x cells)."""
x_np = to_np(x)
return x_np.mean(axis=0), x_np.std(axis=0)
# ----------------------------
# 1) Scatter plots colored by guide
# ----------------------------
def scatter_by_guide(model, cg, log2=False):
df = model[cg].meta.copy()
X = to_np(model[cg].x_true)
if log2:
# filter strictly positive before log
mask_pos = (X > 0).all(axis=0)
X = X[:, mask_pos]
df = df.loc[mask_pos].reset_index(drop=True)
X = np.log2(X)
x_mean, x_std = X.mean(axis=0), X.std(axis=0)
plt.figure(figsize=(6, 5))
for guide, subidx in df.groupby('guide').groups.items():
color = guide_colors.get(guide, 'black')
plt.scatter(x_mean[subidx], x_std[subidx], s=14, alpha=0.8, color=color, label=guide)
plt.xlabel('mean x_true' + (' (log2)' if log2 else ''))
plt.ylabel('std x_true' + (' (log2)' if log2 else ''))
plt.title(f'{cg}: mean vs std of x_true' + (' (log2)' if log2 else ''))
plt.grid(True, linewidth=0.5, alpha=0.4)
plt.legend(title='guide', fontsize=8, markerscale=1.2, frameon=False)
plt.tight_layout()
plt.show()
# Usage (your two plots):
# raw scale
# scatter_by_guide(model, cg, log2=False)
# log2 scale
# scatter_by_guide(model, cg, log2=True)
def scatter_ci95_by_guide(model, cg, log2=False, full_width=False):
"""
Scatter of per-cell mean vs 95% CI width (or half-width) of x_true samples.
- x: mean over samples
- y: CI_95 width (q97.5 - q2.5) if full_width=True,
else half-width = 0.5 * (q97.5 - q2.5)
Colors points by model[cg].meta['guide'] using guide_colors.
"""
df = model[cg].meta.copy()
X = to_np(model[cg].x_true) # shape [S, N] (samples x cells)
if log2:
# Keep only cells strictly positive across samples before log2
mask_pos = (X > 0).all(axis=0)
X = X[:, mask_pos]
df = df.loc[mask_pos].reset_index(drop=True)
X = np.log2(X)
x_mean = X.mean(axis=0)
q_lo = np.percentile(X, 2.5, axis=0)
q_hi = np.percentile(X, 97.5, axis=0)
ci_w = (q_hi - q_lo)
y_val = ci_w if full_width else 0.5 * ci_w # half-width by default
plt.figure(figsize=(6, 5))
for guide, idx in df.groupby('guide').groups.items():
color = guide_colors.get(guide, 'black')
plt.scatter(x_mean[idx], y_val[idx], s=14, alpha=0.85, color=color, label=guide)
plt.xlabel('mean x_true' + (' (log2)' if log2 else ''))
ylabel = '95% CI ' + ('width' if full_width else 'half-width')
ylabel += ' of x_true' + (' (log2)' if log2 else '')
plt.ylabel(ylabel)
plt.title(f'{cg}: mean vs 95% CI of x_true' + (' (log2)' if log2 else ''))
plt.grid(True, linewidth=0.5, alpha=0.4)
plt.legend(title='guide', fontsize=8, markerscale=1.2, frameon=False)
plt.tight_layout()
plt.show()
# ----------------------------------------------------------
# 2) Violin: x-axis = guide, color = target, x_true on log2
# ----------------------------------------------------------
def violin_by_guide_log2(model, cg):
df = model[cg].meta.copy()
X = to_np(model[cg].x_true)
# Keep only cells with strictly positive across samples before log2
pos_mask = (X > 0).all(axis=0)
X = X[:, pos_mask]
df = df.loc[pos_mask].reset_index(drop=True)
Xlog = np.log2(X)
x_cell_mean = Xlog.mean(axis=0)
df = df.assign(x_true_mean_log2=x_cell_mean)
# Order guides nicely
guide_order = sorted(df['guide'].astype(str).unique(),
key=lambda g: (g.split('_')[0], int(g.split('_')[1]) if '_' in g and g.split('_')[1].isdigit() else 0))
data = [df.loc[df['guide'] == g, 'x_true_mean_log2'].values for g in guide_order]
# Color each violin by its guide's target
colors = []
for g in guide_order:
tvals = df.loc[df['guide'] == g, 'target'].astype(str).unique()
t = tvals[0] if len(tvals) else 'NTC'
colors.append(target_colors.get(t, 'gray'))
plt.figure(figsize=(max(6, 1.2*len(guide_order)), 4.8))
parts = plt.violinplot(data, showmeans=True, showextrema=False)
for body, c in zip(parts['bodies'], colors):
body.set_facecolor(c)
body.set_edgecolor('black')
body.set_alpha(0.85)
# mean line style
if 'cmeans' in parts:
parts['cmeans'].set_color('black')
parts['cmeans'].set_linewidth(1.0)
plt.xticks(ticks=np.arange(1, len(guide_order)+1), labels=guide_order, rotation=45, ha='right')
plt.xlabel('guide')
plt.ylabel('x_true mean (log₂)')
plt.title(f'{cg}: x_true (log₂) by guide (colored by target)')
plt.grid(True, linewidth=0.5, alpha=0.4, axis='y')
plt.tight_layout()
plt.show()
# ----------------------------------------------------------
# 3) Density (KDE): filled, color = guide, x_true on log2
# KDE over per-cell mean of x_true (log2)
# ----------------------------------------------------------
def filled_density_by_guide_log2(model, cg, bw=None):
df = model[cg].meta.copy()
X = to_np(model[cg].x_true)
pos_mask = (X > 0).all(axis=0)
X = X[:, pos_mask]
df = df.loc[pos_mask].reset_index(drop=True)
Xlog = np.log2(X)
x_cell_mean = Xlog.mean(axis=0)
df = df.assign(x_true_mean_log2=x_cell_mean)
# global x-range for all guides
xmin, xmax = np.percentile(x_cell_mean, [0.5, 99.5])
xs = np.linspace(xmin, xmax, 400)
plt.figure(figsize=(7, 4.8))
# Keep legend order stable
guides = sorted(df['guide'].astype(str).unique(),
key=lambda g: (g.split('_')[0], int(g.split('_')[1]) if '_' in g and g.split('_')[1].isdigit() else 0))
for g in guides:
vals = df.loc[df['guide'] == g, 'x_true_mean_log2'].values
color = guide_colors.get(g, 'black')
if len(np.unique(vals)) < 2:
# Not enough variance for KDE; plot a small filled bump
y = np.exp(-0.5*((xs - vals[0]) / 0.01)**2) # tiny Gaussian bump
plt.fill_between(xs, 0*y, y, color=color, alpha=0.35, label=g)
continue
kde = gaussian_kde(vals, bw_method=bw)
ys = kde(xs)
plt.fill_between(xs, 0, ys, color=color, alpha=0.35, label=g)
plt.plot(xs, ys, color=color, linewidth=1.5)
plt.xlabel('x_true mean (log₂)')
plt.ylabel('density')
plt.title(f'{cg}: filled density by guide (log₂)')
plt.grid(True, linewidth=0.5, alpha=0.4)
plt.legend(title='guide', fontsize=8, frameon=False, ncol=2)
plt.tight_layout()
plt.show()
time: 5.68 ms (started: 2025-11-03 21:15:22 +01:00)
In [12]:
# -----------------------------------
# Example calls for your three genes
# -----------------------------------
cgs = ['GFI1B', 'GEMIN5', 'DDX6']
for cg in cgs:
# Your two scatter plots:
scatter_by_guide(model, cg, log2=False)
scatter_by_guide(model, cg, log2=True)
# Your two scatter plots:
scatter_ci95_by_guide(model, cg, log2=False, full_width=True)
scatter_ci95_by_guide(model, cg, log2=True, full_width=True)
# Violin by target:
violin_by_guide_log2(model, cg) # or log2=True if you prefer
# Density by guide:
filled_density_by_guide_log2(model, cg) # or log2=True for log2 scale
time: 4.26 s (started: 2025-11-03 21:15:22 +01:00)
In [13]:
cgs = ['GFI1B', 'GEMIN5', 'DDX6']
for cg in cgs:
model[cg].set_technical_groups(['Sample'])
[INFO] Set technical_group_code with 1 groups based on ['Sample'] [INFO] Set technical_group_code with 1 groups based on ['Sample'] [INFO] Set technical_group_code with 1 groups based on ['Sample'] time: 5.06 ms (started: 2025-11-03 21:15:26 +01:00)
In [14]:
cgs = ['GFI1B', 'GEMIN5', 'DDX6']
tgs = ['MYB', 'HES4', 'GAPDH']
for tg in tgs:
for cg in cgs:
model[cg].plot_xy_data(tg, window=100, sum_factor_col='sum_factor_adj', show_correction='uncorrected');
time: 2.04 s (started: 2025-11-03 21:15:26 +01:00)
In [15]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
# --- Prepare data ---
counts = model['GEMIN5'].counts
meta = model['GEMIN5'].meta.copy()
# Ensure meta is indexed by cell IDs so it can align to counts.columns
if 'cell' in meta.columns and not meta.index.equals(counts.columns):
meta = meta.set_index('cell')
# Build an aligned mask for the counts' columns
mask = meta['target'].reindex(counts.columns).eq('ntc').fillna(False)
# Extract counts for GEMIN5 in NTC cells
gemin5_ntc = counts.loc['GEMIN5', mask]
# --- Plot ---
plt.figure(figsize=(6, 4))
counts_hist, bins, patches = plt.hist(
gemin5_ntc,
bins=30,
color="#4C72B0",
edgecolor="black",
alpha=0.8
)
# --- Add percentage labels above each bar ---
total = counts_hist.sum()
for count, bin_left, bin_right in zip(counts_hist, bins[:-1], bins[1:]):
if count > 0:
percent = 100 * count / total
plt.text(
(bin_left + bin_right) / 2,
count,
f"{percent:.1f}%",
ha="center",
va="bottom",
fontsize=8,
rotation=0
)
# --- Styling ---
plt.title("GEMIN5 Expression in NTC Cells", fontsize=14, weight="bold", pad=15)
plt.xlabel("Counts", fontsize=12)
plt.ylabel("Number of Cells", fontsize=12)
plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.5)
sns.despine(trim=True)
plt.tight_layout()
plt.show()
time: 131 ms (started: 2025-11-03 21:15:28 +01:00)
In [16]:
np.mean(gemin5_ntc > 0)
Out[16]:
np.float64(0.39168343393695504)
time: 1.48 ms (started: 2025-11-03 21:15:29 +01:00)
Fit trans¶
In [17]:
# remove hacky tech group added
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
model[cg].meta.drop('technical_group_code', axis=1, inplace=True)
time: 6.82 ms (started: 2025-11-03 21:15:29 +01:00)
Refit sumfactor¶
In [18]:
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
model[cg].refit_sumfactor(sum_factor_col_old='clustered.sum.factor')
[INFO] Created 'sum_factor_new' in meta with xtrue-based adjustment. [INFO] Created 'sum_factor_new' in meta with xtrue-based adjustment. [INFO] Created 'sum_factor_new' in meta with xtrue-based adjustment. time: 731 ms (started: 2025-11-03 21:15:29 +01:00)
In [19]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
cgs = ['GFI1B', 'GEMIN5', 'DDX6']
# --- Define consistent palettes ---
palette = {
'GFI1B': [cm.Greens(i) for i in np.linspace(0.4, 0.9, 3)], # GFI1B_[1-3]
'NTC': [cm.Greys(i) for i in np.linspace(0.4, 0.8, 5)], # NTC_[1-5]
'GEMIN5':[cm.Blues(i) for i in np.linspace(0.4, 0.8, 2)], # GEMIN5_[1-2]
'DDX6': [cm.Reds(i) for i in np.linspace(0.4, 0.8, 3)], # DDX6_[1,3]
}
# Flatten into guide→color dictionary
guide_colors = {}
for gene, colors in palette.items():
for i, color in enumerate(colors, start=1):
guide_colors[f"{gene}_{i}"] = color
# --- Plot ---
fig, axes = plt.subplots(1, len(cgs), figsize=(5*len(cgs), 4), sharex=False, sharey=False)
if len(cgs) == 1:
axes = [axes]
for ax, cg in zip(axes, cgs):
df = model[cg].meta.copy()
df = df[(df['clustered.sum.factor'] > 0) & (df['sum_factor_new'] > 0)]
for guide, sub in df.groupby('guide'):
color = guide_colors.get(guide, 'black')
ax.scatter(
sub['clustered.sum.factor'],
sub['sum_factor_new'],
s=12,
alpha=0.8,
color=color,
label=guide,
)
ax.set_xscale('log', base=2)
ax.set_yscale('log', base=2)
ax.set_title(cg)
ax.set_xlabel('clustered.sum.factor (log₂)')
ax.set_ylabel('sum_factor_new (log₂)')
ax.grid(True, linewidth=0.5, alpha=0.4)
ax.legend(title='guide', fontsize=8, markerscale=1.2, frameon=False)
plt.tight_layout()
plt.show()
time: 678 ms (started: 2025-11-03 21:15:29 +01:00)
In [20]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
cgs = ['GFI1B', 'GEMIN5', 'DDX6']
# --- Define consistent palettes ---
palette = {
'GFI1B': [cm.Greens(i) for i in np.linspace(0.4, 0.9, 3)], # GFI1B_[1-3]
'NTC': [cm.Greys(i) for i in np.linspace(0.4, 0.8, 5)], # NTC_[1-5]
'GEMIN5':[cm.Blues(i) for i in np.linspace(0.4, 0.8, 2)], # GEMIN5_[1-2]
'DDX6': [cm.Reds(i) for i in np.linspace(0.4, 0.8, 3)], # DDX6_[1,3]
}
# Flatten into guide→color dictionary
guide_colors = {}
for gene, colors in palette.items():
for i, color in enumerate(colors, start=1):
guide_colors[f"{gene}_{i}"] = color
# --- Plot ---
fig, axes = plt.subplots(1, len(cgs), figsize=(5*len(cgs), 4), sharex=False, sharey=False)
if len(cgs) == 1:
axes = [axes]
for ax, cg in zip(axes, cgs):
df = model[cg].meta.copy()
df = df[(df['sum_factor_adj'] > 0) & (df['sum_factor_new'] > 0)]
for guide, sub in df.groupby('guide'):
color = guide_colors.get(guide, 'black')
ax.scatter(
sub['sum_factor_adj'],
sub['sum_factor_new'],
s=12,
alpha=0.8,
color=color,
label=guide,
)
ax.set_xscale('log', base=2)
ax.set_yscale('log', base=2)
ax.set_title(cg)
ax.set_xlabel('sum_factor_adj (log₂)')
ax.set_ylabel('sum_factor_new (log₂)')
ax.grid(True, linewidth=0.5, alpha=0.4)
ax.legend(title='guide', fontsize=8, markerscale=1.2, frameon=False)
plt.tight_layout()
plt.show()
time: 814 ms (started: 2025-11-03 21:15:30 +01:00)
Fit¶
In [21]:
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
# --- TRANS FIT: Load if exists, otherwise fit and save ---
trans_fit_path = os.path.join(model[cg].output_dir, model[cg].label, 'posterior_samples_trans_gene.pt')
if os.path.exists(trans_fit_path):
print("[INFO] Loading existing trans fit...")
model[cg].load_trans_fit()
else:
print("[INFO] Running trans fit (this may take a while)...")
model[cg].fit_trans(sum_factor_col="sum_factor_new", tolerance=0)
model[cg].save_trans_fit()
[INFO] Loading existing trans fit... [LOAD] posterior_samples_trans (modality: gene, 13791 features) ← /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_GFI1B/posterior_samples_trans_gene.pt [LOAD] gene.posterior_samples_trans (distribution: negbinom, 13791 features) ← /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_GFI1B/posterior_samples_trans_gene.pt [LOAD] Trans fit loaded from /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_GFI1B [LOAD] Modalities loaded: ['cis', 'gene'] [INFO] Loading existing trans fit... [LOAD] posterior_samples_trans (modality: gene, 13791 features) ← /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_GEMIN5/posterior_samples_trans_gene.pt [LOAD] gene.posterior_samples_trans (distribution: negbinom, 13791 features) ← /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_GEMIN5/posterior_samples_trans_gene.pt [LOAD] Trans fit loaded from /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_GEMIN5 [LOAD] Modalities loaded: ['cis', 'gene'] [INFO] Loading existing trans fit... [LOAD] posterior_samples_trans (modality: gene, 13791 features) ← /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_DDX6/posterior_samples_trans_gene.pt [LOAD] gene.posterior_samples_trans (distribution: negbinom, 13791 features) ← /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_DDX6/posterior_samples_trans_gene.pt [LOAD] Trans fit loaded from /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_DDX6 [LOAD] Modalities loaded: ['cis', 'gene'] time: 2.5 s (started: 2025-11-03 21:15:31 +01:00)
Plot results¶
Mean v CI plots¶
In [22]:
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
samples_alpha = model[cg].posterior_samples_trans['K_a'][:, 0, :].cpu().numpy()
samples_n = model[cg].posterior_samples_trans['n_a'][:, 0, :].cpu().numpy()
# Mean values
alpha_mean = samples_alpha.mean(axis=0)
n_mean = samples_n.mean(axis=0)
# 95% credible intervals
alpha_lo, alpha_hi = np.percentile(samples_alpha, [2.5, 97.5], axis=0)
n_lo, n_hi = np.percentile(samples_n, [2.5, 97.5], axis=0)
# 95% CI widths
alpha_ci_width = alpha_hi - alpha_lo
n_ci_width = n_hi - n_lo
# Dependency mask based on n's CI excluding 0
dependent_mask = (n_lo > 0) | (n_hi < 0)
dependent_pct = dependent_mask.mean() * 100
# === Plot 1: K_a ===
plt.figure()
plt.scatter(alpha_mean[~dependent_mask], alpha_ci_width[~dependent_mask],
s=5, alpha=0.4, color='black', label='not dependent')
plt.scatter(alpha_mean[dependent_mask], alpha_ci_width[dependent_mask],
s=5, alpha=0.6, color='blue', label='dependent')
plt.xlabel(r'Mean $K_a$')
plt.ylabel(r'95% CI width of $K_a$')
plt.title(f"{cg} — $K_a$ uncertainty vs mean ({dependent_pct:.1f}% dependent)")
plt.axhline(0, color='black', linestyle=':', linewidth=1)
plt.axvline(0, color='black', linestyle=':', linewidth=1)
plt.legend(frameon=False, loc='center left', bbox_to_anchor=(1.02, 0.5))
plt.tight_layout()
plt.show()
# === Plot 2: n_a ===
plt.figure()
plt.scatter(n_mean[~dependent_mask], n_ci_width[~dependent_mask],
s=5, alpha=0.4, color='black', label='not dependent')
plt.scatter(n_mean[dependent_mask], n_ci_width[dependent_mask],
s=5, alpha=0.6, color='blue', label='dependent')
plt.xlabel(r'Mean $n$ (Hill coefficient)')
plt.ylabel(r'95% CI width of $n$')
plt.title(f"{cg} — $n_a$ uncertainty vs mean ({dependent_pct:.1f}% dependent)")
plt.axhline(0, color='black', linestyle=':', linewidth=1)
plt.axvline(0, color='black', linestyle=':', linewidth=1)
plt.legend(frameon=False, loc='center left', bbox_to_anchor=(1.02, 0.5))
plt.tight_layout()
plt.show()
time: 2.72 s (started: 2025-11-03 20:43:05 +01:00)
In [23]:
def hill_xinf_samples(K_samps, n_samps, tol=0.2, x_max=None):
S, T = n_samps.shape
xinf = np.full((S, T), np.nan, dtype=float)
m = np.abs(n_samps)
base = (m - 1.0) / (m + 1.0) # in (0,1) when |n|>1
valid = m > (1.0 + tol)
with np.errstate(divide='ignore', invalid='ignore'):
log_xinf = np.log(K_samps) + (1.0 / n_samps) * np.log(base)
xinf[valid] = np.exp(log_xinf[valid])
if x_max is not None:
xinf[xinf > x_max] = np.nan
return xinf
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
K_samps = model[cg].posterior_samples_trans['K_a'][:, 0, :].detach().cpu().numpy()
n_samps = model[cg].posterior_samples_trans['n_a'][:, 0, :].detach().cpu().numpy()
n_mean = n_samps.mean(axis=0)
n_lo, n_hi = np.percentile(n_samps, [2.5, 97.5], axis=0)
n_ci_width = n_hi - n_lo
dependent_mask = (n_lo > 0) | (n_hi < 0)
dependent_pct = dependent_mask.mean() * 100
# --- Inflection plot, color by NaN fraction (dark = few NaNs) ---
xinf_samps = hill_xinf_samples(K_samps, n_samps, tol=0.2, x_max=None)
xinf_mean = np.nanmean(xinf_samps, axis=0)
xinf_lo, xinf_hi = np.nanpercentile(xinf_samps, [2.5, 97.5], axis=0)
xinf_ci_width = xinf_hi - xinf_lo
frac_nan = np.mean(np.isnan(xinf_samps), axis=0)
mask = dependent_mask & np.isfinite(xinf_mean) & np.isfinite(xinf_ci_width)
plt.figure()
sc = plt.scatter(
xinf_mean[mask], xinf_ci_width[mask],
c=frac_nan[mask], cmap='Blues_r', vmin=0, vmax=1,
s=8, alpha=0.9
)
plt.xlabel(r'Mean inflection $x_{\mathrm{inf}}$')
plt.ylabel(r'95% CI width of $x_{\mathrm{inf}}$')
plt.title(f"{cg} — inflection (dependent only; {dependent_pct:.1f}% dependent)")
cbar = plt.colorbar(sc, pad=0.02)
cbar.set_label('fraction NaN in $x_{\\mathrm{inf}}$ samples (NaN are where abs(n)<1)')
plt.axhline(0, color='black', linestyle=':', linewidth=1)
plt.tight_layout()
plt.show()
# --- n plot, gene-level points with alpha blending ---
plt.figure()
# not dependent (light grey)
plt.scatter(
n_mean[~dependent_mask], n_ci_width[~dependent_mask],
s=5, alpha=0.3, color='grey', label='not dependent'
)
# dependent (blue, darker = overlap)
plt.scatter(
n_mean[dependent_mask], n_ci_width[dependent_mask],
s=5, alpha=0.2, color='blue', label='dependent'
)
plt.xlabel(r'Mean $n$ (Hill coefficient)')
plt.ylabel(r'95% CI width of $n$')
plt.title(f"{cg} — $n$ uncertainty vs mean ({dependent_pct:.1f}% dependent)")
plt.legend(frameon=False, loc='center left', bbox_to_anchor=(1.02, 0.5))
plt.axhline(0, color='black', linestyle=':', linewidth=1)
plt.tight_layout()
plt.show()
/tmp/ipykernel_2603555/1393261059.py:9: RuntimeWarning: overflow encountered in divide log_xinf = np.log(K_samps) + (1.0 / n_samps) * np.log(base) /tmp/ipykernel_2603555/1393261059.py:28: RuntimeWarning: Mean of empty slice xinf_mean = np.nanmean(xinf_samps, axis=0) /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/software/miniconda3/envs/pyroenv/lib/python3.12/site-packages/numpy/lib/_nanfunctions_impl.py:1650: RuntimeWarning: All-NaN slice encountered return fnb._ureduce(a,
time: 5.16 s (started: 2025-11-03 20:43:07 +01:00)
Posterior density lines plots¶
Gene-level parameters¶
In [24]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from scipy.stats import gaussian_kde
from matplotlib import cm
def plot_posterior_density_lines(
samples, # [S, T]
title="Posterior density lines",
sort_by="median",
subset_mask=None,
cmap="viridis",
alpha_overall=0.5,
density_gamma=0.7,
norm_global=True,
y_quantiles=(0.5, 99.5),
grid_points=350,
linewidth=0.8,
add_median_lines=True,
y_label=r"$\theta$",
ax=None,
show=True,
y_range=None, # <-- NEW
):
"""Plot per-feature posterior densities as vertical colour lines."""
samples = np.asarray(samples)
if samples.ndim == 1:
samples = samples[:, None]
elif samples.ndim > 2:
samples = samples.reshape(samples.shape[0], -1)
S, T = samples.shape
if subset_mask is not None:
subset_mask = np.asarray(subset_mask, dtype=bool)
samples = samples[:, subset_mask]
S, T = samples.shape
if sort_by == "median":
order = np.argsort(np.nanmedian(samples, axis=0))
elif sort_by == "mean":
order = np.argsort(np.nanmean(samples, axis=0))
else:
order = np.arange(T)
samples_sorted = samples[:, order]
# --- y-range: either from samples, or overridden explicitly ---
if y_range is None:
y_min, y_max = np.nanpercentile(samples_sorted, y_quantiles)
else:
y_min, y_max = y_range
y_grid = np.linspace(y_min, y_max, grid_points)
# KDE per feature
dens_list = []
for t in range(T):
vals = samples_sorted[:, t]
vals = vals[~np.isnan(vals)]
if vals.size < 2:
dens = np.zeros_like(y_grid)
else:
kde = gaussian_kde(vals)
dens = kde(y_grid)
dens_list.append(dens)
dens_mat = np.stack(dens_list, axis=0)
if norm_global:
m = dens_mat.max() + 1e-12
dens_norm = (dens_mat / m) ** density_gamma
else:
m = dens_mat.max(axis=1, keepdims=True) + 1e-12
dens_norm = (dens_mat / m) ** density_gamma
L = len(y_grid)
segs_all, cols_all = [], []
cmap_obj = cm.get_cmap(cmap)
for x_pos in range(T):
x = np.full(L, x_pos, dtype=float)
pts = np.column_stack([x, y_grid])
segs = np.stack([pts[:-1], pts[1:]], axis=1)
c = cmap_obj(dens_norm[x_pos, :-1])
c[:, 3] = alpha_overall
segs_all.append(segs)
cols_all.append(c)
segs_all = np.concatenate(segs_all, axis=0)
cols_all = np.concatenate(cols_all, axis=0)
if ax is None:
fig, ax = plt.subplots(figsize=(12, 4))
else:
fig = ax.figure
lc = LineCollection(segs_all, colors=cols_all, linewidths=linewidth)
ax.add_collection(lc)
if add_median_lines:
med = np.nanmedian(samples_sorted, axis=0)
for i, m_val in enumerate(med):
arr = np.asarray(m_val)
if arr.size == 0:
continue
m_val_scalar = float(arr.ravel()[0])
if not np.isfinite(m_val_scalar):
continue
ax.hlines(
m_val_scalar,
i - 0.4,
i + 0.4,
color="white",
linewidth=0.5,
alpha=0.9,
)
ax.set_xlim(-0.5, T - 0.5)
ax.set_ylim(y_min, y_max)
ax.axhline(0, color='black', linestyle=':', linewidth=1)
ax.set_xlabel("Genes (equal-width bins)")
ax.set_ylabel(y_label)
if title:
ax.set_title(title)
ax.set_xticks([])
if show:
fig.tight_layout()
plt.show()
return ax
def dependency_mask_from_n(n_samps, ci=95.0):
"""95% CI of n excludes 0."""
lo_q = (100 - ci) / 2.0
hi_q = 100 - lo_q
lo, hi = np.percentile(n_samps, [lo_q, hi_q], axis=0)
return (lo > 0) | (hi < 0)
def abs_n_gt_tol_mask(n_samps, tol=1.0):
"""|median(n)| > 1+tol_n_for_xinf (tol is the extra beyond 1)."""
med_abs = np.abs(np.median(n_samps, axis=0))
return med_abs > (1.0 + tol)
def hill_xinf_samples(K_samps, n_samps, tol_n=0.0):
"""
Compute per-sample x_inf for Hill curves.
K_samps, n_samps: [S, T]
Returns [S, T] with NaN where |n| <= 1+tol_n.
"""
m = np.abs(n_samps)
base = (m - 1.0) / (m + 1.0)
with np.errstate(divide='ignore', invalid='ignore'):
log_xinf = np.log(K_samps) + (1.0 / n_samps) * np.log(base)
xinf = np.exp(log_xinf)
xinf[m <= (1.0 + tol_n)] = np.nan
return xinf
def hill_y(x, A, alpha, K, n, eps=1e-8):
"""
Vectorized Hill function for arrays (broadcasting OK).
A, alpha, K, n can be [S, T]; x can be scalar or array.
"""
x = np.asarray(x, dtype=float)
# ensure broadcasting: add trailing axes if needed
while A.ndim > x.ndim:
x = np.expand_dims(x, axis=0)
x_n = np.power(x, n)
K_n = np.power(K, n)
y = A + alpha * x_n / (K_n + x_n + eps)
return y
import numpy as np
def compute_log2fc_metrics(A_samps, alpha_samps, Vmax_samps, K_samps, n_samps,
x_true_samps, eps=1e-6, n_zero_tol=1e-6):
"""
Directional log2 fold-change metrics for:
y(x) = A + alpha * Vmax * x^n / (K^n + x^n)
log2fc_full: between y(x→∞) and y(x→0), with sign determined by n.
log2fc_obs: between y(x_max_obs) and y(x_min_obs).
"""
# ensure arrays
A_samps = np.asarray(A_samps)
alpha_samps = np.asarray(alpha_samps)
Vmax_samps = np.asarray(Vmax_samps)
K_samps = np.asarray(K_samps)
n_samps = np.asarray(n_samps)
# --- observed x range from mean x_true across samples per cell ---
X = np.asarray(x_true_samps) # [S, N_cells]
x_means_per_cell = X.mean(axis=0) # [N_cells]
x_min_obs = float(x_means_per_cell.min())
x_max_obs = float(x_means_per_cell.max())
A = A_samps
alpha = alpha_samps
Vmax = Vmax_samps
# ---------------- full-range FC with n sign ----------------
# sign of n determines whether y increases or decreases with x
n_sign = np.sign(n_samps)
# treat near-zero n as flat (no direction)
flat_mask = np.abs(n_samps) < n_zero_tol
n_sign[flat_mask] = 0.0
# asymptotes:
# if n > 0: y(0) = A, y(∞) = A + alpha*Vmax
# if n < 0: y(0) = A + alpha*Vmax, y(∞) = A
y0_full = np.where(n_sign >= 0, A, A + alpha * Vmax)
yinf_full = np.where(n_sign >= 0, A + alpha * Vmax, A)
log2fc_full = np.zeros_like(A, dtype=float)
changing_mask = n_sign != 0.0
log2fc_full[changing_mask] = np.log2(
(yinf_full[changing_mask] + eps) /
(y0_full[changing_mask] + eps)
)
# flat_mask stays at 0
# ---------------- helper: y(x) under this parametrisation ----------------
def y_hill(x_scalar, A, alpha, Vmax, K, n, eps_inner=1e-8):
"""
y(x) = A + alpha * Vmax * x^n / (K^n + x^n), evaluated at scalar x.
"""
x = float(x_scalar)
x_safe = x + eps_inner
K_safe = K + eps_inner
with np.errstate(divide='ignore', invalid='ignore'):
x_log = np.log(x_safe)
K_log = np.log(K_safe)
x_n = np.exp(n * x_log)
K_n = np.exp(n * K_log)
frac = x_n / (K_n + x_n + eps_inner)
h = Vmax * frac
return A + alpha * h
# ---------------- observed-range FC: x_min_obs -> x_max_obs ----------------
Y_min_obs = y_hill(x_min_obs, A, alpha, Vmax, K_samps, n_samps, eps_inner=eps)
Y_max_obs = y_hill(x_max_obs, A, alpha, Vmax, K_samps, n_samps, eps_inner=eps)
log2fc_obs = np.log2((Y_max_obs + eps) / (Y_min_obs + eps))
return log2fc_full, log2fc_obs, x_min_obs, x_max_obs
time: 4.68 ms (started: 2025-11-03 20:43:13 +01:00)
In [25]:
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
n_samps = model[cg].posterior_samples_trans['n_a'][:, 0, :].detach().cpu().numpy()
dep_mask = dependency_mask_from_n(n_samps)
plot_posterior_density_lines(
n_samps,
title=f"{cg} — posterior of $n$",
subset_mask=dep_mask,
cmap="viridis",
alpha_overall=0.45,
density_gamma=0.7,
add_median_lines=True,
y_label=r"$n$ (Hill coefficient)"
)
/tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead. cmap_obj = cm.get_cmap(cmap)
time: 2min 7s (started: 2025-11-03 20:43:13 +01:00)
In [26]:
tol_n_for_xinf = 0.2 # extra margin beyond 1
def log2_pos(a):
a = np.asarray(a)
out = np.full_like(a, np.nan, dtype=float)
mask = a > 0
out[mask] = np.log2(a[mask])
return out
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
K_samps = model[cg].posterior_samples_trans['K_a'][:, 0, :].detach().cpu().numpy()
n_samps = model[cg].posterior_samples_trans['n_a'][:, 0, :].detach().cpu().numpy()
dep_mask = dependency_mask_from_n(n_samps)
abs_gt_tol = abs_n_gt_tol_mask(n_samps, tol=tol_n_for_xinf)
mask = dep_mask & abs_gt_tol
xinf_samps = hill_xinf_samples(K_samps, n_samps, tol_n=tol_n_for_xinf)
log2_xinf_samps = log2_pos(xinf_samps)
# also drop genes where xinf is NaN for all samples
mask &= ~np.all(np.isnan(xinf_samps), axis=0)
plot_posterior_density_lines(
log2_xinf_samps,
title=f"{cg} — posterior of $x_{{\\mathrm{{inf}}}}$ (dependent, |n|>1+{tol_n_for_xinf})",
subset_mask=mask,
cmap="viridis",
alpha_overall=0.45,
density_gamma=0.7,
add_median_lines=True,
y_label=r'$\log_2 x_{\mathrm{inf}}$'
)
/tmp/ipykernel_2603555/114306776.py:156: RuntimeWarning: overflow encountered in divide log_xinf = np.log(K_samps) + (1.0 / n_samps) * np.log(base) /tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead. cmap_obj = cm.get_cmap(cmap)
time: 1min 54s (started: 2025-11-03 20:45:21 +01:00)
In [27]:
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
K_samps = model[cg].posterior_samples_trans['K_a'][:, 0, :].detach().cpu().numpy()
n_samps = model[cg].posterior_samples_trans['n_a'][:, 0, :].detach().cpu().numpy()
dep_mask = dependency_mask_from_n(n_samps)
log2_K_samps = log2_pos(K_samps)
plot_posterior_density_lines(
log2_K_samps,
title=f"{cg} — posterior of $K_a$ (dependent only)",
subset_mask=dep_mask,
cmap="viridis",
alpha_overall=0.45,
density_gamma=0.7,
add_median_lines=True,
y_label=r"$\log_2 K_a$"
)
/tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead. cmap_obj = cm.get_cmap(cmap)
time: 2min (started: 2025-11-03 20:47:15 +01:00)
In [28]:
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
# --- Extract posterior samples ---
A_samps = model[cg].posterior_samples_trans['A'][:, 0, :].cpu().numpy()
alpha_samps = model[cg].posterior_samples_trans['alpha'][:, 0, :].cpu().numpy()
Vmax_samps = model[cg].posterior_samples_trans['Vmax_a'][:, 0, :].cpu().numpy()
K_samps = model[cg].posterior_samples_trans['K_a'][:, 0, :].cpu().numpy()
n_samps = model[cg].posterior_samples_trans['n_a'][:, 0, :].cpu().numpy()
x_true_samps = model[cg].x_true.detach().cpu().numpy()
# --- Compute log2FC metrics ---
log2fc_full, log2fc_obs, x_min_obs, x_max_obs = compute_log2fc_metrics(
A_samps, alpha_samps, Vmax_samps, K_samps, n_samps, x_true_samps
)
# --- Gene-level means ---
full_mean = np.nanmean(log2fc_full, axis=0)
obs_mean = np.nanmean(log2fc_obs, axis=0)
# --- Dependency mask (95% CI of n excludes 0) ---
lo, hi = np.percentile(n_samps, [2.5, 97.5], axis=0)
dep_mask = (lo > 0) | (hi < 0)
# ============================================================
# 1️⃣ Correlation between full-range and observed log2FC
# ============================================================
plt.figure(figsize=(5.5, 5))
plt.scatter(full_mean[~dep_mask], obs_mean[~dep_mask],
s=10, alpha=0.25, color='grey', label='not dependent')
plt.scatter(full_mean[dep_mask], obs_mean[dep_mask],
s=10, alpha=0.6, color='blue', label='dependent')
# 1:1 line
lim_min = min(full_mean.min(), obs_mean.min())
lim_max = max(full_mean.max(), obs_mean.max())
plt.plot([lim_min, lim_max], [lim_min, lim_max],
color='black', linestyle=':', linewidth=1)
plt.xlabel(r'log$_2$ Fold-Change (full dynamic range: $A \rightarrow A + \alpha V_{\max}$)')
plt.ylabel(r'log$_2$ Fold-Change (within observed $x_{\min} \rightarrow x_{\max}$)')
plt.title(f"{cg}: Relationship between full and observed dynamic range")
plt.legend(frameon=False, loc='best')
plt.grid(True, linewidth=0.5, alpha=0.4)
plt.tight_layout()
plt.show()
# ============================================================
# 2️⃣ Distribution of full-range log₂FC
# ============================================================
plot_posterior_density_lines(
log2fc_full,
title=fr"{cg} — Posterior distribution of log$_2$ Fold-Change (full dynamic range)",
subset_mask=dep_mask,
cmap="viridis",
alpha_overall=0.45,
density_gamma=0.7,
add_median_lines=True,
y_label=r"log$_2$ Fold-Change (full dynamic range: $A \rightarrow A + \alpha V_{\max}$)",
)
# ============================================================
# 3️⃣ Distribution of observed-range log₂FC
# ============================================================
plot_posterior_density_lines(
log2fc_obs,
title=fr"{cg} — Posterior distribution of log$_2$ Fold-Change (within observed $x$-range)",
subset_mask=dep_mask,
cmap="viridis",
alpha_overall=0.45,
density_gamma=0.7,
add_median_lines=True,
y_label=r"log$_2$ Fold-Change (observed range: $x_{\min} \rightarrow x_{\max}$)",
)
/tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead. cmap_obj = cm.get_cmap(cmap)
time: 4min 51s (started: 2025-11-03 20:49:16 +01:00)
Cell-level parameters¶
In [29]:
from matplotlib.lines import Line2D
from matplotlib.transforms import Bbox
def plot_xtrue_density_by_guide(
model,
cg,
log2=False,
cmap="viridis",
alpha_overall=0.5,
density_gamma=0.7,
norm_global=True,
y_quantiles=(0.5, 99.5),
grid_points=350,
linewidth=0.8,
group_by_guide=True,
):
"""
One vertical density line per *cell* for x_true, matching the style of
plot_posterior_density_lines, with guides indicated by:
- coloured horizontal median ticks per cell
- a coloured bar between title and axes showing guide per cell
- a legend mapping colour -> guide
group_by_guide:
True -> cells grouped by guide, then median within guide
False -> cells ordered only by median, but colours still show guide ID.
"""
df = model[cg].meta.copy()
X = to_np(model[cg].x_true) # [S, N_cells]
# log2 transform without dropping guides
if log2:
eps = 1e-6
X = np.log2(np.maximum(X, eps))
samples = np.asarray(X) # [S, N]
S, N = samples.shape
guides = df['guide'].astype(str).to_numpy() # length N
# ---------- choose ordering ----------
med_per_cell = np.nanmedian(samples, axis=0)
if group_by_guide:
# order cells: by guide, then median within guide
def guide_sort_key(g):
parts = g.split('_')
root = parts[0]
idx = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0
return (root, idx)
unique_guides = sorted(np.unique(guides), key=guide_sort_key)
guide_block_rank = {g: i for i, g in enumerate(unique_guides)}
guide_ranks = np.array([guide_block_rank[g] for g in guides])
order = np.lexsort((med_per_cell, guide_ranks)) # (N,)
else:
unique_guides = sorted(np.unique(guides))
order = np.argsort(med_per_cell)
samples_sorted = samples[:, order]
guides_sorted = guides[order]
# ---------- draw density background using generic function ----------
ylabel = "x_true" + (" (log₂)" if log2 else "")
# no axes title (we'll use a figure-level title instead)
ax = plot_posterior_density_lines(
samples_sorted,
title="", # <-- important: no axis title
sort_by=None,
subset_mask=None,
cmap=cmap,
alpha_overall=alpha_overall,
density_gamma=density_gamma,
norm_global=norm_global,
y_quantiles=y_quantiles,
grid_points=grid_points,
linewidth=linewidth,
add_median_lines=False,
y_label=ylabel,
ax=None,
show=False, # we'll manage layout
)
fig = ax.figure
# First, lay out the main axes nicely
fig.tight_layout(rect=[0, 0, 0.98, 0.93]) # leave some top margin
# Now get the *final* axis position after tight_layout
from matplotlib.transforms import Bbox
pos = ax.get_position()
# ---------- coloured median ticks per cell (guide-coded) ----------
med = np.nanmedian(samples_sorted, axis=0)
for i, m_val in enumerate(med):
arr = np.asarray(m_val)
if arr.size == 0:
continue
m_val_scalar = float(arr.ravel()[0])
if not np.isfinite(m_val_scalar):
continue
g = str(guides_sorted[i])
tick_color = guide_colors.get(g, (1.0, 1.0, 1.0, 1.0))
ax.hlines(
m_val_scalar,
i - 0.4,
i + 0.4,
color=tick_color,
linewidth=0.8,
alpha=0.9,
linestyle="solid",
zorder=3,
)
ax.set_xlim(-0.5, N - 0.5)
ax.set_xlabel("Cells (grouped by guide)" if group_by_guide else "Cells (ordered by median)")
ax.axhline(0, color='black', linestyle=':', linewidth=1)
# ---------- coloured bar between title and axes ----------
bar_height_frac = 0.06 # ~6% of axis height
bar_gap_frac = 0.02 # small gap above axes
bar_bottom = pos.y1 + (pos.height * bar_gap_frac)
bar_top = bar_bottom + (pos.height * bar_height_frac)
bar_pos = Bbox.from_extents(pos.x0, bar_bottom, pos.x1, bar_top)
bar_ax = fig.add_axes(bar_pos)
bar_ax.set_xlim(-0.5, N - 0.5)
bar_ax.set_ylim(0, 1)
bar_ax.axis("off")
# contiguous runs of the same guide along x
start = 0
current = guides_sorted[0]
segments = []
for i in range(1, N):
if guides_sorted[i] != current:
segments.append((start, i - 1, current))
start = i
current = guides_sorted[i]
segments.append((start, N - 1, current))
for s, e, g in segments:
color = guide_colors.get(g, "black")
bar_ax.axvspan(s - 0.5, e + 0.5, color=color)
# ---------- legend ----------
handles = []
labels = []
for g in unique_guides:
color = guide_colors.get(g, 'black')
handles.append(Line2D([0], [0], color=color, lw=3))
labels.append(g)
ax.legend(handles, labels, title="guide", frameon=False,
bbox_to_anchor=(1.02, 0.5), loc="center left")
# figure-level title at the very top
fig.suptitle(f"{cg}: posterior of x_true per cell", y=0.99)
# no more tight_layout calls here
plt.show()
time: 3.02 ms (started: 2025-11-03 20:54:07 +01:00)
In [30]:
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
plot_xtrue_density_by_guide(model, cg, log2=True, group_by_guide=True)
plot_xtrue_density_by_guide(model, cg, log2=True, group_by_guide=False)
/tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead. cmap_obj = cm.get_cmap(cmap)
/tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead. cmap_obj = cm.get_cmap(cmap)
/tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead. cmap_obj = cm.get_cmap(cmap)
/tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead. cmap_obj = cm.get_cmap(cmap)
/tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead. cmap_obj = cm.get_cmap(cmap)
/tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead. cmap_obj = cm.get_cmap(cmap)
time: 2min 11s (started: 2025-11-03 20:54:07 +01:00)
In [31]:
from scipy.stats import gaussian_kde
tol_n_for_xinf = 0.2 # extra margin beyond 1
def log2_pos(a):
a = np.asarray(a)
out = np.full_like(a, np.nan, dtype=float)
mask = a > 0
out[mask] = np.log2(a[mask])
return out
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
K_samps = model[cg].posterior_samples_trans['K_a'][:, 0, :].detach().cpu().numpy()
n_samps = model[cg].posterior_samples_trans['n_a'][:, 0, :].detach().cpu().numpy()
dep_mask = dependency_mask_from_n(n_samps)
abs_gt_tol = abs_n_gt_tol_mask(n_samps, tol=tol_n_for_xinf)
mask = dep_mask & abs_gt_tol
xinf_samps = hill_xinf_samples(K_samps, n_samps, tol_n=tol_n_for_xinf)
log2_xinf_samps = log2_pos(xinf_samps)
# drop genes where x_inf is NaN for all samples
mask &= ~np.all(np.isnan(xinf_samps), axis=0)
# ---- compute log2 mean x_true per cell & its y-range ----
df_meta = model[cg].meta.copy()
x_true_samps = model[cg].x_true.detach().cpu().numpy() # [S, N_cells]
xtrue_mean_per_cell = x_true_samps.mean(axis=0) # [N_cells]
log2_xtrue_means = log2_pos(xtrue_mean_per_cell)
vals_all = log2_xtrue_means[~np.isnan(log2_xtrue_means)]
# y-range driven by x_true distribution (e.g. central 99% of cells)
y_min = np.percentile(vals_all, 0.5)
y_max = np.percentile(vals_all, 99.5)
y_range = (y_min, y_max)
# ---- global 95% CI of log2 x_inf (for reference) ----
inf_vals_all = log2_xinf_samps[:, mask]
inf_vals_all = inf_vals_all[~np.isnan(inf_vals_all)]
if inf_vals_all.size > 0:
ci_lo, ci_hi = np.percentile(inf_vals_all, [2.5, 97.5])
else:
ci_lo, ci_hi = y_min, y_max
# ---------------- figure with 2 panels: main + side density ----------------
fig, (ax_main, ax_side) = plt.subplots(
1, 2,
figsize=(8, 5),
gridspec_kw={"width_ratios": [4, 1], "wspace": 0.05},
sharey=True,
)
# ----- main posterior density of log2 x_inf -----
ax_main = plot_posterior_density_lines(
log2_xinf_samps,
title=rf"{cg} — posterior of $\log_2 x_{{\mathrm{{inf}}}}$ (dependent, |n|>1+{tol_n_for_xinf})",
subset_mask=mask,
cmap="viridis",
alpha_overall=0.45,
density_gamma=0.7,
add_median_lines=True,
y_label=r'$\log_2 x_{\mathrm{inf}}$',
ax=ax_main,
show=False,
y_range=y_range, # <-- use x_true-driven scale
)
ax_main.set_xlabel("Genes (equal-width bins)")
ax_main.set_ylim(y_min, y_max)
# left y ticks
ax_main.set_yticks(np.linspace(y_min, y_max, 5))
ax_main.yaxis.set_ticks_position('left')
ax_main.tick_params(axis='y', which='both', length=4)
# Optionally indicate global 95% CI region for x_inf
ax_main.axhline(ci_lo, color='white', linestyle=':', linewidth=0.7, alpha=0.7)
ax_main.axhline(ci_hi, color='white', linestyle=':', linewidth=0.7, alpha=0.7)
# ----- sideways densities of log2 mean x_true by target -----
ax_side.set_xlabel(r'density of $\log_2 x_{\mathrm{true}}$', fontsize=9)
ax_side.xaxis.set_label_position('top')
targets = df_meta['target'].astype(str).to_numpy()
uniq_targets = sorted(np.unique(targets))
y_grid = np.linspace(y_min, y_max, 400)
for t in uniq_targets:
mask_t = targets == t
vals_t = log2_xtrue_means[mask_t]
vals_t = vals_t[~np.isnan(vals_t)]
if vals_t.size == 0:
continue
color = target_colors.get(t, 'grey')
if vals_t.size < 2:
# tiny bump if no variance
y0 = vals_t[0]
bump = np.exp(-0.5 * ((y_grid - y0) / 0.05) ** 2)
bump /= bump.max() + 1e-12
ax_side.fill_betweenx(y_grid, 0, bump, color=color, alpha=0.45)
ax_side.plot(bump, y_grid, color=color, linewidth=1.0)
else:
kde = gaussian_kde(vals_t)
dens_t = kde(y_grid)
dens_t /= dens_t.max() + 1e-12
ax_side.fill_betweenx(y_grid, 0, dens_t, color=color, alpha=0.45)
ax_side.plot(dens_t, y_grid, color=color, linewidth=1.0)
ax_side.set_xlim(0, 1.05)
# mirror the y-axis on the right for sanity
ax_side.yaxis.set_ticks_position('right')
ax_side.yaxis.set_label_position('right')
ax_side.set_yticks(np.linspace(y_min, y_max, 5))
ax_side.tick_params(axis='y', which='both', length=4)
ax_side.set_ylabel("") # keep only left label to avoid clutter
fig.tight_layout()
plt.show()
/tmp/ipykernel_2603555/114306776.py:156: RuntimeWarning: overflow encountered in divide log_xinf = np.log(K_samps) + (1.0 / n_samps) * np.log(base) /tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead. cmap_obj = cm.get_cmap(cmap) /tmp/ipykernel_2603555/4235410720.py:121: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect. fig.tight_layout()
/tmp/ipykernel_2603555/114306776.py:156: RuntimeWarning: overflow encountered in divide log_xinf = np.log(K_samps) + (1.0 / n_samps) * np.log(base) /tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead. cmap_obj = cm.get_cmap(cmap) /tmp/ipykernel_2603555/4235410720.py:121: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect. fig.tight_layout()
/tmp/ipykernel_2603555/114306776.py:156: RuntimeWarning: overflow encountered in divide log_xinf = np.log(K_samps) + (1.0 / n_samps) * np.log(base) /tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead. cmap_obj = cm.get_cmap(cmap) /tmp/ipykernel_2603555/4235410720.py:121: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect. fig.tight_layout()
time: 2min 15s (started: 2025-11-03 20:56:18 +01:00)
Mean results plots¶
In [32]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
# assumes:
# - compute_log2fc_metrics is already defined (directional)
# - dependency_mask_from_n is already defined
# - target_colors dict exists, e.g. {'NTC': ..., 'GFI1B': ..., ...}
# -----------------------------
# x_inf helper
# -----------------------------
def hill_xinf_samples(K_samps, n_samps, tol_n=0.0):
"""
Compute per-sample x_inf for Hill curves.
For the positive Hill function using (K, n), the point of inflexion is:
x_inf = K * ((|n|-1)/(|n|+1))^(1/n)
This only makes sense when |n| > 1 + tol_n. Otherwise we return NaN.
K_samps, n_samps: [S, T]
Returns:
xinf: [S, T] with NaN where |n| <= 1+tol_n.
"""
K_samps = np.asarray(K_samps)
n_samps = np.asarray(n_samps)
m = np.abs(n_samps)
base = (m - 1.0) / (m + 1.0)
with np.errstate(divide='ignore', invalid='ignore'):
log_xinf = np.log(K_samps) + (1.0 / n_samps) * np.log(base)
xinf = np.exp(log_xinf)
xinf[m <= (1.0 + tol_n)] = np.nan
return xinf
# -----------------------------
# generic scatter helper
# -----------------------------
def scatter_nice_dep_plot(
x_vals,
y_vals,
dep_mask,
cg,
xlabel,
ylabel,
title,
target_colors,
alpha=0.3,
s=5,
add_zero_guides=True,
):
"""
Make a dependency-coloured scatter with a square plotting area
and the legend placed fully outside (to the right).
"""
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
x_vals = np.asarray(x_vals)
y_vals = np.asarray(y_vals)
dep_mask = np.asarray(dep_mask, dtype=bool)
finite = np.isfinite(x_vals) & np.isfinite(y_vals)
if not np.any(finite):
print(f"[{cg}] No finite points for this plot.")
return
x = x_vals[finite]
y = y_vals[finite]
dep = dep_mask[finite]
nondep = ~dep
color_ntc = target_colors.get("NTC", cm.Greys(0.6))
color_cg = target_colors.get(cg, "blue")
# make a *square plot box*, leave right margin for legend
fig, ax = plt.subplots(figsize=(5.5, 5.5))
ax.set_box_aspect(1) # ensures square axes box
# scatter points
ax.scatter(
x[nondep],
y[nondep],
s=s,
alpha=alpha,
color=color_ntc,
label="non-dependent",
)
ax.scatter(
x[dep],
y[dep],
s=s,
alpha=alpha,
color=color_cg,
label=f"{cg} dependent",
)
# zero guides
if add_zero_guides:
ax.axhline(0, color="black", linestyle=":", linewidth=1)
ax.axvline(0, color="black", linestyle=":", linewidth=1)
# labels and title
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_title(title)
# legend outside, keeping the plot perfectly square
leg = ax.legend(
frameon=False,
loc="center left",
bbox_to_anchor=(1.02, 0.5), # move outside right
borderaxespad=0.0,
)
for lh in leg.legend_handles:
try:
lh.set_sizes([50])
except Exception:
pass
# tidy layout — prevents label cutoff but doesn’t distort square axes
fig.subplots_adjust(right=0.78) # reserve space for legend
ax.grid(True, linewidth=0.5, alpha=0.4)
plt.show()
# -----------------------------
# main loop over cis genes
# -----------------------------
# -----------------------------
# main loop over cis genes
# -----------------------------
tol_n_for_xinf = 0.2 # extra margin beyond 1 for x_inf
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
print(f"\n=== {cg}: FC/n/x_inf relationships ===")
# posterior samples
A_samps = model[cg].posterior_samples_trans['A'][:, 0, :].detach().cpu().numpy()
alpha_samps = model[cg].posterior_samples_trans['alpha'][:, 0, :].detach().cpu().numpy()
Vmax_samps = model[cg].posterior_samples_trans['Vmax_a'][:, 0, :].detach().cpu().numpy()
K_samps = model[cg].posterior_samples_trans['K_a'][:, 0, :].detach().cpu().numpy()
n_samps = model[cg].posterior_samples_trans['n_a'][:, 0, :].detach().cpu().numpy()
x_true_samps = model[cg].x_true.detach().cpu().numpy()
# directional log2FCs from the Hill model
log2fc_full, log2fc_obs, x_min_obs, x_max_obs = compute_log2fc_metrics(
A_samps, alpha_samps, Vmax_samps, K_samps, n_samps, x_true_samps
)
# per-gene means
n_mean = np.mean(n_samps, axis=0) # [T]
log2fc_full_mean = np.mean(log2fc_full, axis=0) # [T]
# dependency mask
dep_mask = dependency_mask_from_n(n_samps)
dep_pct = 100.0 * np.sum(dep_mask) / len(dep_mask)
# x_inf samples + mean, then log2-transform
xinf_samps = hill_xinf_samples(K_samps, n_samps, tol_n=tol_n_for_xinf) # [S, T]
xinf_mean = np.nanmean(xinf_samps, axis=0) # [T]
log2_xinf_mean = np.log2(xinf_mean)
xinf_finite_mask = np.isfinite(log2_xinf_mean)
# -----------------------------
# 1) mean n vs mean full-range log2FC
# -----------------------------
scatter_nice_dep_plot(
x_vals=n_mean,
y_vals=log2fc_full_mean,
dep_mask=dep_mask,
cg=cg,
target_colors=target_colors,
xlabel=rf"mean $n$ (Hill coefficient)",
ylabel=rf"mean log$_2$FC ($y(x\to\infty)$ vs $y(x\to 0)$)",
title=f"{cg}: full-range log$_2$FC vs Hill coefficient\n({dep_pct:.1f}% dependent)",
)
# -----------------------------
# 2) mean n vs log2(x_inf)
# -----------------------------
dep_mask_xinf = dep_mask & xinf_finite_mask
scatter_nice_dep_plot(
x_vals=n_mean[xinf_finite_mask],
y_vals=log2_xinf_mean[xinf_finite_mask],
dep_mask=dep_mask_xinf[xinf_finite_mask],
cg=cg,
target_colors=target_colors,
xlabel=rf"mean $n$ (Hill coefficient)",
ylabel=rf"log$_2(x_{{\mathrm{{inf}}}})$",
title=f"{cg}: log$_2(x_{{\mathrm{{inf}}}})$ vs Hill coefficient\n(only |n|>1+{tol_n_for_xinf})",
)
# -----------------------------
# 3) mean full-range log2FC vs log2(x_inf)
# -----------------------------
scatter_nice_dep_plot(
x_vals=log2fc_full_mean[xinf_finite_mask],
y_vals=log2_xinf_mean[xinf_finite_mask],
dep_mask=dep_mask_xinf[xinf_finite_mask],
cg=cg,
target_colors=target_colors,
xlabel=rf"mean log$_2$FC ($y(x\to\infty)$ vs $y(x\to 0)$)",
ylabel=rf"log$_2(x_{{\mathrm{{inf}}}})$",
title=f"{cg}: log$_2(x_{{\mathrm{{inf}}}})$ vs full-range log$_2$FC\n(only |n|>1+{tol_n_for_xinf})",
)
=== GFI1B: FC/n/x_inf relationships ===
<>:195: SyntaxWarning: invalid escape sequence '\m'
<>:209: SyntaxWarning: invalid escape sequence '\m'
<>:195: SyntaxWarning: invalid escape sequence '\m'
<>:209: SyntaxWarning: invalid escape sequence '\m'
/tmp/ipykernel_2603555/2547606006.py:195: SyntaxWarning: invalid escape sequence '\m'
title=f"{cg}: log$_2(x_{{\mathrm{{inf}}}})$ vs Hill coefficient\n(only |n|>1+{tol_n_for_xinf})",
/tmp/ipykernel_2603555/2547606006.py:209: SyntaxWarning: invalid escape sequence '\m'
title=f"{cg}: log$_2(x_{{\mathrm{{inf}}}})$ vs full-range log$_2$FC\n(only |n|>1+{tol_n_for_xinf})",
/tmp/ipykernel_2603555/2547606006.py:32: RuntimeWarning: overflow encountered in divide
log_xinf = np.log(K_samps) + (1.0 / n_samps) * np.log(base)
/tmp/ipykernel_2603555/2547606006.py:165: RuntimeWarning: Mean of empty slice
xinf_mean = np.nanmean(xinf_samps, axis=0) # [T]
=== GEMIN5: FC/n/x_inf relationships ===
=== DDX6: FC/n/x_inf relationships ===
time: 6.78 s (started: 2025-11-03 20:58:34 +01:00)
Compare to edgeR results¶
comparison plots¶
In [22]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.lines import Line2D
# --------------------------------------------------------------------
# Basic helpers
# --------------------------------------------------------------------
def lighten(color, amount=0.3):
"""Lighten an RGBA/RGB colour by mixing with white."""
c = np.array(mcolors.to_rgba(color))
white = np.array([1, 1, 1, 1])
return tuple((1 - amount) * c + amount * white)
def darken(color, amount=0.3):
"""Darken an RGBA/RGB colour by mixing with black."""
c = np.array(mcolors.to_rgba(color))
black = np.array([0, 0, 0, 1])
return tuple((1 - amount) * c + amount * black)
def dependency_mask_from_n(n_samps, ci=95.0):
"""
Dependency mask based on n: 95% CI excludes 0.
n_samps: [S, T]
"""
lo_q = (100 - ci) / 2.0
hi_q = 100 - lo_q
lo, hi = np.percentile(n_samps, [lo_q, hi_q], axis=0)
return (lo > 0) | (hi < 0)
# --------------------------------------------------------------------
# Hill-based log2FC metrics
# --------------------------------------------------------------------
def compute_log2fc_metrics(A_samps, alpha_samps, Vmax_samps, K_samps, n_samps,
x_true_samps, eps=1e-6):
"""
Compute *directional* log2 fold-change metrics for the Hill-based model:
y(x) = A + alpha * Vmax * x^n / (K^n + x^n)
Parameters
----------
A_samps, alpha_samps, Vmax_samps, K_samps, n_samps : [S, T]
Posterior samples for each parameter.
x_true_samps : [S, N_cells]
Posterior samples of x_true for this cis gene.
eps : float
Small constant for numerical stability.
Returns
-------
log2fc_full : [S, T]
Full-range log2FC: y(x→∞) vs y(x→0), *directional w.r.t. x increasing*.
- n > 0: log2( (A+αVmax) / A )
- n < 0: log2( A / (A+αVmax) )
log2fc_obs : [S, T]
Observed-range log2FC: y(x_max_obs) vs y(x_min_obs), directional in x.
x_min_obs, x_max_obs : float
Observed min/max of mean x_true across cells.
"""
# ensure arrays
A_samps = np.asarray(A_samps)
alpha_samps = np.asarray(alpha_samps)
Vmax_samps = np.asarray(Vmax_samps)
K_samps = np.asarray(K_samps)
n_samps = np.asarray(n_samps)
# observed x range from mean x_true per cell
X = np.asarray(x_true_samps) # [S, N_cells]
x_means_per_cell = X.mean(axis=0) # [N_cells]
x_min_obs = float(x_means_per_cell.min())
x_max_obs = float(x_means_per_cell.max())
A = A_samps
alpha = alpha_samps
Vmax = Vmax_samps
# --- FULL-RANGE --- #
# Asymptotes: A (low) and A + α·Vmax (high), but which is at x→0 vs x→∞
# depends on the sign of n.
y_low_val = A
y_high_val = A + alpha * Vmax
n_pos = (n_samps >= 0) # True where curve increases with x
# y(x→0) and y(x→∞) according to sign of n
y_at_x0 = np.where(n_pos, y_low_val, y_high_val)
y_at_xinf = np.where(n_pos, y_high_val, y_low_val)
log2fc_full = np.log2((y_at_xinf + eps) / (y_at_x0 + eps))
# --- helper: y(x) for observed-range FC --- #
def y_hill(x_scalar, A, alpha, Vmax, K, n, eps_inner=1e-8):
"""
y(x) = A + alpha * Vmax * x^n / (K^n + x^n)
evaluated at scalar x, broadcasting over [S, T] params.
"""
x = float(x_scalar)
x_safe = x + eps_inner
K_safe = K + eps_inner
with np.errstate(divide='ignore', invalid='ignore'):
x_log = np.log(x_safe)
K_log = np.log(K_safe)
x_n = np.exp(n * x_log)
K_n = np.exp(n * K_log)
frac = x_n / (K_n + x_n + eps_inner)
h = Vmax * frac
return A + alpha * h
# --- OBSERVED-RANGE --- #
# y at min and max of empirical x range (direction is x_min→x_max)
Y_min_obs = y_hill(x_min_obs, A, alpha, Vmax, K_samps, n_samps, eps_inner=eps)
Y_max_obs = y_hill(x_max_obs, A, alpha, Vmax, K_samps, n_samps, eps_inner=eps)
log2fc_obs = np.log2((Y_max_obs + eps) / (Y_min_obs + eps))
return log2fc_full, log2fc_obs, x_min_obs, x_max_obs
def compute_log2fc_obs_for_cells(
A_samps, alpha_samps, Vmax_samps, K_samps, n_samps,
x_true_samps, cell_mask, guide_labels=None, eps=1e-6
):
"""
Compute observed-range log2FC for each gene, given a subset of cells.
- A_samps, alpha_samps, Vmax_samps, K_samps, n_samps: [S, T]
- x_true_samps: [S, N_cells] (cis x_true samples)
- cell_mask: boolean [N_cells], selecting cells to *use* for x_min/x_max
(e.g. cells of a given guide + all NTC cells)
- guide_labels: array-like [N_cells] of guide IDs; if provided, we
first average x_true per guide, then take min/max
across guide means.
Returns:
log2fc_obs: [S, T] (per-sample, per-gene, directional)
x_min_obs, x_max_obs: floats (observed range in this subset)
"""
A_samps = np.asarray(A_samps)
alpha_samps = np.asarray(alpha_samps)
Vmax_samps = np.asarray(Vmax_samps)
K_samps = np.asarray(K_samps)
n_samps = np.asarray(n_samps)
X = np.asarray(x_true_samps)
# subset cells
X_sub = X[:, cell_mask] # [S, N_sub]
# per-cell means over posterior samples
x_means_per_cell = X_sub.mean(axis=0) # [N_sub]
# optionally aggregate per guide first
if guide_labels is not None:
guide_labels = np.asarray(guide_labels)
guides_sub = guide_labels[cell_mask] # [N_sub]
uniq_guides = np.unique(guides_sub)
perguide_means = []
for g in uniq_guides:
perguide_means.append(x_means_per_cell[guides_sub == g].mean())
perguide_means = np.array(perguide_means)
x_min_obs = float(perguide_means.min())
x_max_obs = float(perguide_means.max())
else:
x_min_obs = float(x_means_per_cell.min())
x_max_obs = float(x_means_per_cell.max())
def y_hill(x_scalar, A, alpha, Vmax, K, n, eps_inner=1e-8):
"""
y(x) = A + alpha * Vmax * x^n / (K^n + x^n)
"""
x = float(x_scalar)
x_safe = x + eps_inner
K_safe = K + eps_inner
with np.errstate(divide='ignore', invalid='ignore'):
x_log = np.log(x_safe)
K_log = np.log(K_safe)
x_n = np.exp(n * x_log)
K_n = np.exp(n * K_log)
frac = x_n / (K_n + x_n + eps_inner)
h = Vmax * frac
return A + alpha * h
A = A_samps
alpha = alpha_samps
Vmax = Vmax_samps
Y_min_obs = y_hill(x_min_obs, A, alpha, Vmax, K_samps, n_samps, eps_inner=eps)
Y_max_obs = y_hill(x_max_obs, A, alpha, Vmax, K_samps, n_samps, eps_inner=eps)
log2fc_obs = np.log2((Y_max_obs + eps) / (Y_min_obs + eps))
return log2fc_obs, x_min_obs, x_max_obs
# --------------------------------------------------------------------
# Common scatter + heatmap plotting helper
# --------------------------------------------------------------------
def scatter_and_heatmap_edger_vs_bayes(
df_g,
y_col,
base_target_color,
base_ntc_color,
cg,
guide,
ylabel,
title_suffix,
fc_thresh=0.5,
flip_edger_x=True,
):
"""
For a single guide df_g (already subset to that guide), make:
- scatter plot: edgeR logFC vs bayesDREAM y_col
- 3x3 heatmap of category overlap
df_g must contain: 'logFC', 'ext_sig', 'dependent', y_col.
"""
g_str = str(guide)
# 4 colour classes for scatter (computed before finite mask)
colors = []
for dep, sig in zip(df_g['dependent'], df_g['ext_sig']):
if (not dep) and (not sig):
# neither method calls it
c = base_ntc_color
elif sig and (not dep):
# edgeR only (FDR<0.05, not dependent in bayesDREAM)
c = lighten(base_target_color, 0.4)
elif dep and (not sig):
# bayesDREAM only (dependent, FDR>=0.05)
c = base_target_color
else:
# both: dependent & FDR<0.05
c = darken(base_target_color, 0.4)
colors.append(c)
# restrict to finite values
finite = np.isfinite(df_g['logFC']) & np.isfinite(df_g[y_col])
df_plot = df_g[finite]
if df_plot.empty:
return
colors = np.array(colors)[finite.values]
# x vs y values
x_raw = df_plot['logFC'].values
x_vals = -x_raw if flip_edger_x else x_raw
y_vals = df_plot[y_col].values
# same scale on x & y
v_min = min(x_vals.min(), y_vals.min())
v_max = max(x_vals.max(), y_vals.max())
pad = 0.05 * (v_max - v_min + 1e-6)
x_lim = (v_min - pad, v_max + pad)
y_lim = (v_min - pad, v_max + pad)
# ---------- SCATTER ----------
plt.figure(figsize=(5.5, 5))
plt.scatter(
x_vals,
y_vals,
s=10,
c=colors,
alpha=0.8,
edgecolor='none',
)
# reference lines
plt.axhline(0, color='black', linestyle=':', linewidth=1)
plt.axvline(0, color='black', linestyle=':', linewidth=1)
plt.axhline( fc_thresh, color='black', linestyle='--', linewidth=0.8)
plt.axhline(-fc_thresh, color='black', linestyle='--', linewidth=0.8)
plt.axvline( fc_thresh, color='black', linestyle='--', linewidth=0.8)
plt.axvline(-fc_thresh, color='black', linestyle='--', linewidth=0.8)
plt.xlim(x_lim)
plt.ylim(y_lim)
x_label_prefix = '-' if flip_edger_x else ''
plt.xlabel(fr'{x_label_prefix}log$_2$FC (edgeR; per-guide)')
plt.ylabel(ylabel)
plt.title(f'{cg}, guide {g_str}: edgeR vs bayesDREAM {title_suffix}')
legend_handles = [
Line2D([0], [0], marker='o', linestyle='',
color=base_ntc_color, label='neither (edgeR nor bayesDREAM)'),
Line2D([0], [0], marker='o', linestyle='',
color=lighten(base_target_color, 0.4),
label='edgeR only (FDR<0.05)'),
Line2D([0], [0], marker='o', linestyle='',
color=base_target_color, label='bayesDREAM only (dependent)'),
Line2D([0], [0], marker='o', linestyle='',
color=darken(base_target_color, 0.4),
label='both (FDR<0.05 & dependent)'),
]
plt.legend(handles=legend_handles, frameon=False, fontsize=8, loc='best')
plt.grid(True, linewidth=0.5, alpha=0.4)
plt.tight_layout()
plt.show()
# ---------- HEATMAP ----------
# edgeR: 0 = not sig; 1 = sig & |log2FC| < thresh; 2 = sig & |log2FC| ≥ thresh
edge_cat = np.zeros(df_plot.shape[0], dtype=int)
edge_sig = df_plot['ext_sig'].values
edge_fc = df_plot['logFC'].values
edge_cat[ edge_sig & (np.abs(edge_fc) < fc_thresh) ] = 1
edge_cat[ edge_sig & (np.abs(edge_fc) >= fc_thresh) ] = 2
# bayesDREAM: 0 = not dependent; 1 = dep & small; 2 = dep & large
bayes_cat = np.zeros(df_plot.shape[0], dtype=int)
bayes_dep = df_plot['dependent'].values
bayes_fc = df_plot[y_col].values
bayes_cat[ bayes_dep & (np.abs(bayes_fc) < fc_thresh) ] = 1
bayes_cat[ bayes_dep & (np.abs(bayes_fc) >= fc_thresh) ] = 2
mat = np.zeros((3, 3), dtype=int)
for b, e in zip(bayes_cat, edge_cat):
mat[b, e] += 1
edge_labels = [
"not sig",
fr"sig, |log$_2$FC| < {fc_thresh}",
fr"sig, |log$_2$FC| ≥ {fc_thresh}",
]
bayes_labels = [
"not dependent",
fr"dep, |log$_2$FC| < {fc_thresh}",
fr"dep, |log$_2$FC| ≥ {fc_thresh}",
]
fig, ax_hm = plt.subplots(figsize=(5.2, 4.5))
im = ax_hm.imshow(mat, cmap="Blues")
total = mat.sum()
for i in range(3):
for j in range(3):
count = mat[i, j]
if total > 0:
pct = 100.0 * count / total
txt = f"{count}\n({pct:.1f}%)"
else:
txt = "0"
ax_hm.text(
j, i, txt,
ha="center", va="center",
color="black" if count < total/2 else "white",
fontsize=8,
)
ax_hm.set_xticks([0, 1, 2])
ax_hm.set_yticks([0, 1, 2])
ax_hm.set_xticklabels(edge_labels, rotation=30, ha="right")
ax_hm.set_yticklabels(bayes_labels)
ax_hm.set_xlabel("edgeR category")
ax_hm.set_ylabel("bayesDREAM category")
ax_hm.set_title(f"{cg}, guide {g_str}:\nedgeR vs bayesDREAM categories")
# no colourbar (values already annotated)
plt.tight_layout()
plt.show()
# --------------------------------------------------------------------
# Data preparation helper: shared for full & observed range
# --------------------------------------------------------------------
def prepare_de_for_cg(model, de_df, cg):
"""
Prepare model + edgeR data for a given cis gene cg.
Returns:
A_samps, alpha_samps, Vmax_samps, K_samps, n_samps, x_true_samps,
meta, de_cg (with idx, logFC, FDR, gene, n_mean, dependent, ext_sig),
base_target_color, base_ntc_color
"""
A_samps = model[cg].posterior_samples_trans['A'][:, 0, :].detach().cpu().numpy()
alpha_samps = model[cg].posterior_samples_trans['alpha'][:, 0, :].detach().cpu().numpy()
Vmax_samps = model[cg].posterior_samples_trans['Vmax_a'][:, 0, :].detach().cpu().numpy()
K_samps = model[cg].posterior_samples_trans['K_a'][:, 0, :].detach().cpu().numpy()
n_samps = model[cg].posterior_samples_trans['n_a'][:, 0, :].detach().cpu().numpy()
x_true_samps = model[cg].x_true.detach().cpu().numpy()
meta = model[cg].meta
n_mean = n_samps.mean(axis=0) # [T]
dep_mask = dependency_mask_from_n(n_samps) # [T] bool
T = n_mean.shape[0]
# gene names aligned to posterior arrays
gene_names = np.array(model[cg].get_modality('gene').feature_meta['gene'])
if len(gene_names) != T:
print(f"[{cg}] WARNING: len(gene_names)={len(gene_names)} != T={T}. "
"Trimming gene_names to first T entries.")
gene_names = gene_names[:T]
gene_to_idx = {g: i for i, g in enumerate(gene_names)}
# external DE results (edgeR)
de_cg = de_df.copy()
de_cg = de_cg[de_cg['gene'].isin(gene_to_idx.keys())].copy()
if de_cg.empty:
print(f"[{cg}] No overlapping genes in edgeR results, skipping.")
return None
de_cg['idx'] = de_cg['gene'].map(gene_to_idx)
de_cg = de_cg[de_cg['idx'].notna()].copy()
de_cg['idx'] = de_cg['idx'].astype(int)
de_cg = de_cg[de_cg['idx'] < T].copy()
if de_cg.empty:
print(f"[{cg}] All overlapping edgeR genes had out-of-range indices, skipping.")
return None
idx_vals = de_cg['idx'].values
de_cg['n_mean'] = n_mean[idx_vals]
de_cg['dependent'] = dep_mask[idx_vals]
de_cg['ext_sig'] = de_cg['FDR'] < 0.05 # edgeR significance
base_target_color = target_colors.get(cg, 'blue')
base_ntc_color = target_colors.get('NTC', 'grey')
return (A_samps, alpha_samps, Vmax_samps, K_samps, n_samps,
x_true_samps, meta, de_cg, base_target_color, base_ntc_color)
# --------------------------------------------------------------------
# 1) Full-range comparison: edgeR vs bayesDREAM log2FC_full (directional)
# --------------------------------------------------------------------
def plot_edger_vs_bayes_full_range(cg_list, model, de_df,
fc_thresh=0.5, flip_edger_x=True):
for cg in cg_list:
print(f"\n=== {cg} (full-range) ===")
prep = prepare_de_for_cg(model, de_df, cg)
if prep is None:
continue
(A_samps, alpha_samps, Vmax_samps, K_samps, n_samps,
x_true_samps, meta, de_cg, base_target_color, base_ntc_color) = prep
# full-range log2FC from bayesDREAM
log2fc_full, _, _, _ = compute_log2fc_metrics(
A_samps, alpha_samps, Vmax_samps, K_samps, n_samps, x_true_samps
)
log2fc_full_mean = log2fc_full.mean(axis=0) # [T]
de_cg['log2fc_full'] = log2fc_full_mean[de_cg['idx'].values]
guides_in_model = set(meta['guide'].astype(str).unique())
for guide in sorted(de_cg['guide'].unique()):
g_str = str(guide)
if g_str not in guides_in_model:
continue
df_g = de_cg[de_cg['guide'] == guide].copy()
if df_g.empty:
continue
scatter_and_heatmap_edger_vs_bayes(
df_g=df_g,
y_col='log2fc_full',
base_target_color=base_target_color,
base_ntc_color=base_ntc_color,
cg=cg,
guide=g_str,
ylabel=r'log$_2$FC$_{\mathrm{full}}$ '
r'(bayesDREAM, $x:0\to\infty$)',
title_suffix='(full-range log$_2$FC)',
fc_thresh=fc_thresh,
flip_edger_x=flip_edger_x,
)
# --------------------------------------------------------------------
# 2) Observed-range comparison: edgeR vs bayesDREAM log2FC_obs (guide+NTC)
# --------------------------------------------------------------------
def plot_edger_vs_bayes_observed_range(cg_list, model, de_df,
fc_thresh=0.5, flip_edger_x=True,
aggregate_by_guide=True):
for cg in cg_list:
print(f"\n=== {cg} (observed-range) ===")
prep = prepare_de_for_cg(model, de_df, cg)
if prep is None:
continue
(A_samps, alpha_samps, Vmax_samps, K_samps, n_samps,
x_true_samps, meta, de_cg, base_target_color, base_ntc_color) = prep
guides_in_model = set(meta['guide'].astype(str).unique())
guide_labels_all = meta['guide'].astype(str).values
for guide in sorted(de_cg['guide'].unique()):
g_str = str(guide)
if g_str not in guides_in_model:
continue
df_g = de_cg[de_cg['guide'] == guide].copy()
if df_g.empty:
continue
# cells: this guide + NTC-target cells
guide_mask = guide_labels_all == g_str
ntc_mask = meta['target'].astype(str).str.upper().str.contains('NTC').to_numpy()
cell_mask = guide_mask | ntc_mask
log2fc_obs_guide, x_min_obs, x_max_obs = compute_log2fc_obs_for_cells(
A_samps, alpha_samps, Vmax_samps, K_samps, n_samps,
x_true_samps,
cell_mask=cell_mask,
guide_labels=guide_labels_all if aggregate_by_guide else None,
)
log2fc_obs_mean_guide = log2fc_obs_guide.mean(axis=0)
de_cg.loc[df_g.index, 'log2fc_obs_guide'] = \
log2fc_obs_mean_guide[df_g['idx'].values]
df_g = de_cg.loc[df_g.index].copy()
scatter_and_heatmap_edger_vs_bayes(
df_g=df_g,
y_col='log2fc_obs_guide',
base_target_color=base_target_color,
base_ntc_color=base_ntc_color,
cg=cg,
guide=g_str,
ylabel=r'log$_2$FC$_{\mathrm{obs}}$ '
r'(bayesDREAM; guide+NTC x-range)',
title_suffix='(observed log$_2$FC)',
fc_thresh=fc_thresh,
flip_edger_x=flip_edger_x,
)
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.lines import Line2D
def compare_edger_extreme_per_target(model, de_df, cg, log2fc_thresh=0.5):
"""
For a cis gene `cg`, compare:
- x: log2FC from edgeR (guide with largest |log2FC| per gene)
- y: mean full-range log2FC from bayesDREAM (directional)
Returns:
dict with the 9 heatmap categories, each a list of gene names.
"""
# ======= bayesDREAM side =======
A_samps = model[cg].posterior_samples_trans['A'][:, 0, :].detach().cpu().numpy()
alpha_samps = model[cg].posterior_samples_trans['alpha'][:, 0, :].detach().cpu().numpy()
Vmax_samps = model[cg].posterior_samples_trans['Vmax_a'][:, 0, :].detach().cpu().numpy()
K_samps = model[cg].posterior_samples_trans['K_a'][:, 0, :].detach().cpu().numpy()
n_samps = model[cg].posterior_samples_trans['n_a'][:, 0, :].detach().cpu().numpy()
x_true_samps = model[cg].x_true.detach().cpu().numpy()
meta = model[cg].meta
# directional full-range FC
log2fc_full, _, _, _ = compute_log2fc_metrics(
A_samps, alpha_samps, Vmax_samps, K_samps, n_samps, x_true_samps
)
log2fc_full_mean = log2fc_full.mean(axis=0)
dep_mask = dependency_mask_from_n(n_samps)
T = log2fc_full_mean.shape[0]
# map genes
gene_names = np.array(model[cg].get_modality('gene').feature_meta['gene'])
if len(gene_names) != T:
print(f"[{cg}] WARNING: trimming gene list from {len(gene_names)} to {T}")
gene_names = gene_names[:T]
gene_to_idx = {g: i for i, g in enumerate(gene_names)}
# ======= edgeR side =======
guides = meta['guide'].astype(str).values
targets = meta['target'].astype(str).values
target_guides = {g for g, t in zip(guides, targets) if t.upper() == cg.upper()}
de_cg = de_df[
de_df['guide'].astype(str).isin(target_guides)
& de_df['gene'].isin(gene_to_idx.keys())
].copy()
if de_cg.empty:
print(f"[{cg}] No overlapping edgeR rows for guides targeting {cg}")
return {}
# pick guide with largest |logFC|
idx_max = de_cg.groupby('gene')['logFC'].apply(lambda s: np.argmax(np.abs(s.values)))
extreme_rows = []
for gene, idx_local in idx_max.items():
sub = de_cg[de_cg['gene'] == gene]
if len(sub) > idx_local:
extreme_rows.append(sub.iloc[idx_local])
de_extreme = pd.DataFrame(extreme_rows)
# any guide significant?
sig_any = (
de_cg.groupby('gene')['FDR']
.apply(lambda s: np.any(s.values < 0.05))
.reset_index()
.rename(columns={'FDR': 'any_sig'})
)
agg = de_extreme.merge(sig_any, on='gene', how='left')
# map back to model
agg['idx'] = agg['gene'].map(gene_to_idx)
agg = agg[agg['idx'].notna()].copy()
agg['idx'] = agg['idx'].astype(int)
agg = agg[agg['idx'] < T].copy()
agg['log2fc_full'] = log2fc_full_mean[agg['idx'].values]
agg['dependent'] = dep_mask[agg['idx'].values]
# ======= extract data =======
x = - agg['logFC'].values
y = agg['log2fc_full'].values
dep = agg['dependent'].values
sig = agg['any_sig'].values
genes = agg['gene'].values
finite = np.isfinite(x) & np.isfinite(y)
x, y, dep, sig, genes = x[finite], y[finite], dep[finite], sig[finite], genes[finite]
base_target_color = target_colors.get(cg, 'blue')
base_ntc_color = target_colors.get('NTC', cm.Greys(0.6))
def lighten(c, amount=0.3):
arr = np.array(cm.colors.to_rgba(c))
return tuple((1 - amount) * arr + amount * np.array([1, 1, 1, 1]))
def darken(c, amount=0.3):
arr = np.array(cm.colors.to_rgba(c))
return tuple((1 - amount) * arr + amount * np.array([0, 0, 0, 1]))
# ======= plot: scatter =======
colors = []
for d, sgn in zip(dep, sig):
if (not d) and (not sgn):
c = base_ntc_color
elif sgn and (not d):
c = lighten(base_target_color, 0.4)
elif d and (not sgn):
c = base_target_color
else:
c = darken(base_target_color, 0.4)
colors.append(c)
fig, ax = plt.subplots(figsize=(5.5, 5.5))
ax.set_box_aspect(1)
ax.scatter(x, y, s=10, c=colors, alpha=0.8, edgecolor='none')
for t in [0, log2fc_thresh, -log2fc_thresh]:
ax.axhline(t, color='black', linestyle='--' if t != 0 else ':', linewidth=0.8)
ax.axvline(t, color='black', linestyle='--' if t != 0 else ':', linewidth=0.8)
pad = 0.1 * max(np.ptp(x), np.ptp(y))
ax.set_xlim(x.min() - pad, x.max() + pad)
ax.set_ylim(y.min() - pad, y.max() + pad)
ax.set_xlabel(r'-log$_2$FC (edgeR; guide with largest |log$_2$FC|)')
ax.set_ylabel(r'mean full-range log$_2$FC (bayesDREAM)')
ax.set_title(f"{cg}: edgeR (extreme guide) vs bayesDREAM log$_2$FC$_{{full}}$")
legend_handles = [
Line2D([0], [0], marker='o', linestyle='', color=base_ntc_color, label='neither'),
Line2D([0], [0], marker='o', linestyle='', color=lighten(base_target_color, 0.4), label='edgeR only'),
Line2D([0], [0], marker='o', linestyle='', color=base_target_color, label='bayesDREAM only'),
Line2D([0], [0], marker='o', linestyle='', color=darken(base_target_color, 0.4), label='both'),
]
leg = ax.legend(handles=legend_handles, frameon=False, loc="center left",
bbox_to_anchor=(1.02, 0.5), borderaxespad=0.0, fontsize=8)
for lh in leg.legend_handles:
try: lh.set_sizes([50])
except Exception: pass
fig.subplots_adjust(right=0.78)
ax.grid(True, linewidth=0.5, alpha=0.4)
plt.show()
# ======= confusion categories =======
edge_cat = np.zeros(x.shape[0], dtype=int)
edge_cat[sig & (np.abs(x) < log2fc_thresh)] = 1
edge_cat[sig & (np.abs(x) >= log2fc_thresh)] = 2
bayes_cat = np.zeros(y.shape[0], dtype=int)
bayes_cat[dep & (np.abs(y) < log2fc_thresh)] = 1
bayes_cat[dep & (np.abs(y) >= log2fc_thresh)] = 2
mat = np.zeros((3, 3), dtype=int)
cat_dict = {}
for b in range(3):
for e in range(3):
mask = (bayes_cat == b) & (edge_cat == e)
mat[b, e] = np.sum(mask)
key = f"bayes{b}_edge{e}"
cat_dict[key] = genes[mask].tolist()
# readable key mapping
readable = {
"bayes0_edge0": "notdep_notsig",
"bayes0_edge1": "notdep_sigsmall",
"bayes0_edge2": "notdep_siglarge",
"bayes1_edge0": "depsmall_notsig",
"bayes1_edge1": "depsmall_sigsmall",
"bayes1_edge2": "depsmall_siglarge",
"bayes2_edge0": "deplarge_notsig",
"bayes2_edge1": "deplarge_sigsmall",
"bayes2_edge2": "deplarge_siglarge",
}
gene_comp = {readable[k]: cat_dict[k] for k in readable}
# ======= heatmap =======
edge_labels = ["not sig", f"sig < {log2fc_thresh}", f"sig ≥ {log2fc_thresh}"]
bayes_labels = ["not dep", f"dep < {log2fc_thresh}", f"dep ≥ {log2fc_thresh}"]
fig_hm, ax_hm = plt.subplots(figsize=(5.2, 4.5))
im = ax_hm.imshow(mat, cmap="Blues")
total = mat.sum()
for i in range(3):
for j in range(3):
count = mat[i, j]
pct = 100.0 * count / total if total > 0 else 0
ax_hm.text(j, i, f"{count}\n({pct:.1f}%)", ha="center", va="center",
color="black" if count < total/2 else "white", fontsize=8)
ax_hm.set_xticks([0, 1, 2])
ax_hm.set_yticks([0, 1, 2])
ax_hm.set_xticklabels(edge_labels, rotation=30, ha="right")
ax_hm.set_yticklabels(bayes_labels)
ax_hm.set_xlabel("edgeR category")
ax_hm.set_ylabel("bayesDREAM category")
ax_hm.set_title(f"{cg}: edgeR vs bayesDREAM categories")
plt.tight_layout()
plt.show()
return gene_comp
time: 14.6 ms (started: 2025-11-03 21:15:33 +01:00)
In [23]:
# --------------------------------------------------------------------
# Load edgeR results and call the functions
# --------------------------------------------------------------------
de_df = pd.read_csv(
"/cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/"
"data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/de_results_per_guide.csv"
)
# target_colors must exist already, e.g. your earlier map using cm.Greens, etc.
# Example call:
cis_genes = ['GFI1B', 'GEMIN5', 'DDX6']
# Full-range comparison
plot_edger_vs_bayes_full_range(cis_genes, model, de_df,
fc_thresh=0.5, flip_edger_x=True)
# Observed-range (guide+NTC) comparison
plot_edger_vs_bayes_observed_range(cis_genes, model, de_df,
fc_thresh=0.5, flip_edger_x=True,
aggregate_by_guide=True)
=== GFI1B (full-range) ===
=== GEMIN5 (full-range) ===
=== DDX6 (full-range) ===
=== GFI1B (observed-range) ===
=== GEMIN5 (observed-range) ===
=== DDX6 (observed-range) ===
time: 18.3 s (started: 2025-11-03 21:15:33 +01:00)
In [24]:
# === run and collect ===
gene_comp = {}
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
gene_comp[cg] = compare_edger_extreme_per_target(model, de_df, cg)
time: 1min 24s (started: 2025-11-03 21:15:52 +01:00)
In [25]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
def volcano_edger_per_guide(
model,
de_df,
cg,
fdr_thr=0.05,
logfc_thr=0.5,
n_top_label=10,
):
"""
Volcano plots per guide (edgeR), highlighting genes that are:
- not dependent in bayesDREAM
- FDR < fdr_thr
- |log2FC| > logfc_thr
Top n_top_label genes per guide (smallest FDR) are annotated.
"""
# ---- bayesDREAM dependency per gene for this cis gene ----
n_samps = model[cg].posterior_samples_trans["n_a"][:, 0, :].detach().cpu().numpy()
dep_mask = dependency_mask_from_n(n_samps) # [T] bool
T = dep_mask.shape[0]
gene_names = np.array(model[cg].get_modality('gene').feature_meta['gene'])
if len(gene_names) != T:
print(f"[{cg}] WARNING: len(gene_names)={len(gene_names)} != T={T}; trimming.")
gene_names = gene_names[:T]
gene_to_idx = {g: i for i, g in enumerate(gene_names)}
# guides for this model
guides_in_model = set(model[cg].meta["guide"].astype(str).unique())
base_target_color = target_colors.get(cg, "tab:blue")
highlight_color = base_target_color
background_color = "lightgrey"
# ---- loop over each guide present in edgeR & this model ----
for guide in sorted(de_df["guide"].astype(str).unique()):
g_str = str(guide)
if g_str not in guides_in_model:
continue
de_g = de_df[de_df["guide"].astype(str) == g_str].copy()
if de_g.empty:
continue
# map genes to bayesDREAM indices (where available)
de_g["idx"] = de_g["gene"].map(gene_to_idx)
# start as float column full of NaN
de_g["dependent"] = np.nan
mask_in_model = de_g["idx"].notna()
if mask_in_model.any():
idx_local = de_g.loc[mask_in_model, "idx"].astype(int).values
de_g.loc[mask_in_model, "dependent"] = dep_mask[idx_local].astype(float)
# numeric columns
logFC = de_g["logFC"].to_numpy(dtype=float)
FDR = de_g["FDR"].to_numpy(dtype=float)
# handle FDR=0 for -log10
FDR_safe = np.where(FDR <= 0, 1e-18, FDR)
neg_log10_FDR = -np.log10(FDR_safe)
de_g["neg_log10_FDR"] = neg_log10_FDR
# ---- highlight mask: non-dependent & sig & big effect ----
dep = de_g["dependent"].to_numpy(dtype=float) # float with NaNs
dep_known = ~np.isnan(dep)
not_dep = dep_known & (dep == 0.0)
sig = FDR < fdr_thr
big = np.abs(logFC) > logfc_thr
highlight = not_dep & sig & big
de_g["highlight"] = highlight
# ---- Volcano plot ----
fig, ax = plt.subplots(figsize=(5.5, 5.5))
# reference lines
ax.axhline(-np.log10(fdr_thr), color="black", linestyle="--", linewidth=0.8, alpha=0.5)
ax.axhline(-np.log10(1e-18), color="black", linestyle="--", linewidth=0.8, alpha=0.5)
ax.axvline( logfc_thr, color="black", linestyle="--", linewidth=0.8, alpha=0.5)
ax.axvline(-logfc_thr, color="black", linestyle="--", linewidth=0.8, alpha=0.5)
# background: all points
ax.scatter(
logFC,
neg_log10_FDR,
s=10,
alpha=0.8,
color=background_color,
edgecolor="none",
)
# highlighted points
if highlight.any():
ax.scatter(
logFC[highlight],
neg_log10_FDR[highlight],
s=10,
alpha=0.8,
color=highlight_color,
edgecolor="none",
linewidth=0.3,
)
ax.set_xlabel(r"log$_2$FC (edgeR)")
ax.set_ylabel(r"$- \log_{10}(\mathrm{FDR})$")
ax.set_title(
f"{cg}, guide {g_str}: edgeR volcano\n"
f"(non-dependent & sig & |log2FC|>{logfc_thr} highlighted)"
)
# label top n_top_label genes by smallest FDR
de_g_sorted = de_g.sort_values("FDR", ascending=True)
to_label = de_g_sorted.iloc[:n_top_label].copy()
for _, row in to_label.iterrows():
x0 = float(row["logFC"])
y0 = float(row["neg_log10_FDR"])
gene_name = row["gene"]
ax.text(
x0,
y0,
gene_name,
fontsize=7,
ha="center",
va="bottom",
alpha=0.9,
)
# no legend
ax.grid(True, linewidth=0.5, alpha=0.4)
plt.tight_layout()
plt.show()
time: 3.82 ms (started: 2025-11-03 21:17:16 +01:00)
assess problematic genes¶
In [26]:
for cg in ["GFI1B", "GEMIN5", "DDX6"]:
volcano_edger_per_guide(model, de_df, cg, fdr_thr=0.05, logfc_thr=0.5, n_top_label=0)
time: 2.46 s (started: 2025-11-03 21:17:16 +01:00)
In [50]:
problematic_genes = de_df[(de_df['guide'].str.contains('GFI1B')) & \
(de_df['gene'].isin(gene_comp['GFI1B']['notdep_siglarge'])) &\
(de_df['FDR']<0.05) &\
((de_df['logFC']).abs()>0.5)
]
plt.hist(problematic_genes.value_counts('gene'))
plt.ylabel('number of genes')
plt.xlabel('number of guides in which they are significant')
plt.show()
triple_problems = problematic_genes.value_counts('gene').index[problematic_genes.value_counts('gene') == 3]
print(f'Genes problematic in 3 guides: {triple_problems}')
Genes problematic in 3 guides: Index(['AREG', 'CLC', 'CRYBG2', 'CCL3', 'MT1G', 'STC1', 'MMP1'], dtype='object', name='gene') time: 110 ms (started: 2025-11-03 21:24:05 +01:00)
In [53]:
problematic_genes[problematic_genes['gene'].isin(triple_problems)]
Out[53]:
| logFC | logCPM | F | PValue | FDR | gene | guide | |
|---|---|---|---|---|---|---|---|
| 28238 | -1.838350 | 5.991872 | 18.909756 | 1.393749e-05 | 2.934746e-04 | STC1 | GFI1B_1 |
| 28416 | -1.799965 | 5.802117 | 14.828829 | 1.195865e-04 | 1.979997e-03 | CCL3 | GFI1B_1 |
| 28435 | -1.795806 | 5.914566 | 14.526167 | 1.408513e-04 | 2.280073e-03 | CLC | GFI1B_1 |
| 28499 | -1.489763 | 5.778868 | 13.742414 | 2.135175e-04 | 3.214884e-03 | MT1G | GFI1B_1 |
| 28955 | -0.892110 | 5.792663 | 9.541958 | 2.021793e-03 | 2.032403e-02 | CRYBG2 | GFI1B_1 |
| 29148 | -1.049437 | 5.842847 | 8.339231 | 3.899421e-03 | 3.436474e-02 | AREG | GFI1B_1 |
| 29304 | -1.033280 | 5.782471 | 7.655913 | 5.714823e-03 | 4.579828e-02 | MMP1 | GFI1B_1 |
| 41381 | 3.522624 | 5.914566 | 373.628880 | 0.000000e+00 | 0.000000e+00 | CLC | GFI1B_2 |
| 42041 | -1.943067 | 5.802117 | 24.822402 | 6.555345e-07 | 1.357527e-05 | CCL3 | GFI1B_2 |
| 42404 | -1.289393 | 5.842847 | 16.812098 | 4.204179e-05 | 5.634990e-04 | AREG | GFI1B_2 |
| 42483 | -1.385856 | 5.778868 | 15.758315 | 7.368590e-05 | 9.165598e-04 | MT1G | GFI1B_2 |
| 43111 | -1.033280 | 5.782471 | 10.694745 | 1.094091e-03 | 8.692227e-03 | MMP1 | GFI1B_2 |
| 43471 | -0.921560 | 5.991872 | 8.781050 | 3.056014e-03 | 2.010904e-02 | STC1 | GFI1B_2 |
| 43637 | -0.687874 | 5.792663 | 8.173919 | 4.271267e-03 | 2.604302e-02 | CRYBG2 | GFI1B_2 |
| 55195 | -1.938152 | 5.914566 | 47.425777 | 6.858900e-12 | 3.378506e-09 | CLC | GFI1B_3 |
| 55197 | -1.891080 | 5.802117 | 44.746268 | 2.557740e-11 | 1.175879e-08 | CCL3 | GFI1B_3 |
| 55282 | -1.201129 | 5.778868 | 22.627513 | 2.060848e-06 | 2.471584e-04 | MT1G | GFI1B_3 |
| 55358 | -0.981557 | 5.782471 | 17.414753 | 3.143810e-05 | 2.270127e-03 | MMP1 | GFI1B_3 |
| 55436 | -0.660349 | 5.792663 | 14.509890 | 1.414780e-04 | 7.253773e-03 | CRYBG2 | GFI1B_3 |
| 55610 | -0.703603 | 5.991872 | 10.591538 | 1.142519e-03 | 3.557025e-02 | STC1 | GFI1B_3 |
| 55651 | -0.683465 | 5.842847 | 10.228842 | 1.392780e-03 | 3.963096e-02 | AREG | GFI1B_3 |
time: 7.59 ms (started: 2025-11-03 21:33:10 +01:00)
In [27]:
problematic_genes = de_df[(de_df['guide'].str.contains('GFI1B')) & \
(de_df['gene'].isin(gene_comp['GFI1B']['notdep_siglarge'])) &\
(de_df['FDR']==0)
]
problematic_genes
Out[27]:
| logFC | logCPM | F | PValue | FDR | gene | guide | |
|---|---|---|---|---|---|---|---|
| 27588 | 4.123429 | 5.680410 | 329.657290 | 0.0 | 0.0 | IL1B | GFI1B_1 |
| 41381 | 3.522624 | 5.914566 | 373.628880 | 0.0 | 0.0 | CLC | GFI1B_2 |
| 41385 | 2.922150 | 7.001393 | 279.075077 | 0.0 | 0.0 | PRSS2 | GFI1B_2 |
| 41386 | 3.651268 | 5.883702 | 296.803656 | 0.0 | 0.0 | CCL4L2 | GFI1B_2 |
| 41466 | 2.184632 | 5.656408 | 100.945521 | 0.0 | 0.0 | PLEK | GFI1B_2 |
| 41482 | 1.992146 | 6.092199 | 86.716443 | 0.0 | 0.0 | CCL3L3 | GFI1B_2 |
| 55169 | -3.216667 | 6.684942 | 97.740010 | 0.0 | 0.0 | REN | GFI1B_3 |
time: 47.5 ms (started: 2025-11-03 21:17:18 +01:00)
In [28]:
cgs = ['GFI1B', 'GEMIN5', 'DDX6']
for cg in cgs:
model[cg].set_technical_groups(['Sample'])
[INFO] Set technical_group_code with 1 groups based on ['Sample'] [INFO] Set technical_group_code with 1 groups based on ['Sample'] [INFO] Set technical_group_code with 1 groups based on ['Sample'] time: 4.76 ms (started: 2025-11-03 21:17:18 +01:00)
In [31]:
for tg in problematic_genes['gene'].values:
model['GFI1B'].plot_xy_data(
tg,
window=100,
sum_factor_col='sum_factor_new',
show_correction='uncorrected'
)
model['GFI1B'].plot_xy_data(
tg,
window=100,
sum_factor_col='clustered.sum.factor',
show_correction='uncorrected'
)
time: 3.29 s (started: 2025-11-03 21:18:50 +01:00)
In [51]:
for tg in triple_problems:
model['GFI1B'].plot_xy_data(
tg,
window=100,
sum_factor_col='sum_factor_new',
show_correction='uncorrected'
)
time: 1.59 s (started: 2025-11-03 21:28:50 +01:00)
In [ ]: